From b0bd644a9074e62f0d7f48017f20449b0466d666 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 27 Sep 2024 22:24:08 +0800 Subject: [PATCH 0001/1191] Add norm-sort stats. --- icefall/diagnostics.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 37872f2331..7c3c29f5fc 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -63,12 +63,23 @@ def get_tensor_stats( "rms" -> square before summing, we'll take sqrt later "value" -> just sum x itself "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing + + "rms-sort" -> this is a bit different than the others, it's based on computing the + rms over the specified dim and returning percentiles of the result (11 of them). Returns: stats: a Tensor of shape (x.shape[dim],). count: an integer saying how many items were counted in each element of stats. """ + if stats_type == "rms-sort": + rms = (x ** 2).mean(dim=dim) + rms = rms.flatten() + rms = rms.sort()[0] + rms = rms[ (torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1) ] + count = 1.0 + return rms, count + count = x.numel() // x.shape[dim] if stats_type == "eigs": @@ -164,7 +175,9 @@ def accumulate(self, x, class_name: Optional[str] = None): for dim in range(ndim): this_dim_stats = self.stats[dim] if ndim > 1: - stats_types = ["abs", "max", "min", "positive", "value", "rms"] + # rms-sort is different from the others, it's based on summing over just this + # dim, then sorting and returning the percentiles. + stats_types = ["abs", "max", "min", "positive", "value", "rms", "rms-sort"] if x.shape[dim] <= self.opts.max_eig_dim: stats_types.append("eigs") else: From d703ce44dc652d32518b37864f3eea184f474c97 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 27 Sep 2024 22:43:10 +0800 Subject: [PATCH 0002/1191] Bug fix to rms --- icefall/diagnostics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 7c3c29f5fc..2cd350c07d 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -73,7 +73,7 @@ def get_tensor_stats( """ if stats_type == "rms-sort": - rms = (x ** 2).mean(dim=dim) + rms = (x ** 2).mean(dim=dim).sqrt() rms = rms.flatten() rms = rms.sort()[0] rms = rms[ (torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1) ] From 3293a7babdcc1a1a01e98a35b8b281dc3c06f7a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 28 Oct 2024 12:20:21 +0800 Subject: [PATCH 0003/1191] Add dev-clean and dev-other decoding. --- egs/librispeech/ASR/zipformer/decode.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index df2d555a09..1a81141f96 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -1017,12 +1017,16 @@ def main(): test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( From 97a742164ad20aeea55aeb1e03f004d159a28e8d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Oct 2024 12:04:41 +0800 Subject: [PATCH 0004/1191] working on refactor of optim.py --- egs/librispeech/ASR/zipformer/optim.py | 154 ++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6f5180e29e..3f19199aa8 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -121,6 +121,141 @@ def batched_params(self, param_group, group_params_names): p.copy_(stacked_params[i]) + +def basic_step(group, p, state, grad): + # computes the moving-average squared gradient and divides by + # it; includes adaptive epsilon. (?) + # takes a step in the gradient direction and + # returns it (times the lr). + delta = param_rms * basic_step(group, p, state, grad) + # on some batches, update size of parameter. + return delta + +def basic_step(group, p, state, grad): + # computes basic Adam update using beta2 only. + lr = group["lr"] + if p.numel() == p.shape[0]: + lr = lr * group["scalar_lr_scale"] + beta2 = group["betas"][1] + eps = group["eps"] + # p shape: (batch_size,) or (batch_size, 1, [1,..]) + try: + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) or (batch_size, 1, [1,..]) + except KeyError: + exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + return -lr * grad / denom + + +def scaling_step(group, p, state, grad): + delta = basic_step(group, p, state, grad) + if p.numel() == p.shape[0]: + return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) + + step = state["step"] + size_update_period = group["size_update_period"] + + try: + param_rms = state["param_rms"] + scale_grads = state["scale_grads"] + scale_exp_avg_sq = state["scale_exp_avg_sq"] + except KeyError: + # we know p.ndim > 1 because we'd have returned above if not, so don't worry + # about the speial case of dim=[] that pytorch treats inconsistently. + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = param_rms.to(torch.float) + scale_exp_avg_sq = torch.zeros_like(param_rms) + scale_grads = torch.zeros(size_update_period, *param_rms.shape, + dtype=torch.float, device=p.device) + state["param_rms"] = param_rms + state["scale_exp_avg_sq"] = scale_exp_avg_sq + state["scale_grads"] = scale_grads + + + # on every step, update the gradient w.r.t. the scale of the parameter, we + # store these as a batch and periodically update the size (for speed only, to + # avoid too many operations). + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + + # periodically recompute the value of param_rms. + if step % size_update_period == size_update_period - 1: + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + + # scale the step size by param_rms. This is the most important "scaling" part of + # ScaledAdam + delta = delta * param_rms + + if step % size_update_period == size_update_period - 1 and step > 0: + # This block updates the size of parameter by adding a step ("delta") value in + # the direction of either shrinking or growing it. + beta2 = group["betas"][1] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + batch_size = p.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta.add_(p * scale_step) + + return delta + + +def momentum_step(group, p, state, grad): + delta = scaling_step(group, p, state, grad) + beta1 = group["betas"][0] + try: + stored_delta = state["delta"] + except KeyError: + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + stored_delta.mul_(beta1) + stored_delta.add_(delta, alpha=(1-beta1)) + # we don't bother doing the "bias correction" part of Adam for beta1 because this is just + # an edge effect that affects the first 10 or so batches; and the effect of not doing it + # is just to do a slower update for the first few batches, which will help stability. + return stored_delta + + + + class ScaledAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update @@ -147,8 +282,7 @@ class ScaledAdam(BatchedOptimizer): by this quantity. betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the scale of each parameter tensor and scalar parameters of the mode.. If each parameter were decomposed as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the scaling factor on the learning rate of p_scale. @@ -355,8 +489,22 @@ def step(self, closure=None): # State initialization if len(state) == 0: self._init_state(group, p, state) + # TODO: remove this. + + + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + grad = (p.grad if clipping_scale == 1.0 else p.grad * clipping_scale) + p += momentum_step(group, p.detach(), state, grad) + + state["step"] = cur_step + 1 + - self._step_one_batch(group, p, state, clipping_scale) + #self._step_one_batch(group, p, state, clipping_scale) return loss From 7d619a1e4749fb2f1c15fd273296bd6e71536ae1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Oct 2024 12:13:23 +0800 Subject: [PATCH 0005/1191] Bug fixes s --- egs/librispeech/ASR/zipformer/optim.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3f19199aa8..481165fb03 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -143,6 +143,7 @@ def basic_step(group, p, state, grad): exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) or (batch_size, 1, [1,..]) except KeyError: exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["exp_avg_sq"] = exp_avg_sq exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -174,8 +175,8 @@ def scaling_step(group, p, state, grad): scale_grads = torch.zeros(size_update_period, *param_rms.shape, dtype=torch.float, device=p.device) state["param_rms"] = param_rms - state["scale_exp_avg_sq"] = scale_exp_avg_sq state["scale_grads"] = scale_grads + state["scale_exp_avg_sq"] = scale_exp_avg_sq # on every step, update the gradient w.r.t. the scale of the parameter, we @@ -486,10 +487,6 @@ def step(self, closure=None): raise RuntimeError( "ScaledAdam optimizer does not support sparse gradients" ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - # TODO: remove this. try: From 1f3aa86cfca84e5b3f6a5dae74ccace73edb6388 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Oct 2024 13:11:12 +0800 Subject: [PATCH 0006/1191] Code cleanup, finish refactoring optimizer --- egs/librispeech/ASR/zipformer/optim.py | 255 ++----------------------- 1 file changed, 18 insertions(+), 237 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 481165fb03..2c5a6ce4e6 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -123,16 +123,7 @@ def batched_params(self, param_group, group_params_names): def basic_step(group, p, state, grad): - # computes the moving-average squared gradient and divides by - # it; includes adaptive epsilon. (?) - # takes a step in the gradient direction and - # returns it (times the lr). - delta = param_rms * basic_step(group, p, state, grad) - # on some batches, update size of parameter. - return delta - -def basic_step(group, p, state, grad): - # computes basic Adam update using beta2 only. + # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. lr = group["lr"] if p.numel() == p.shape[0]: lr = lr * group["scalar_lr_scale"] @@ -150,7 +141,11 @@ def basic_step(group, p, state, grad): # bias_correction2 is like in Adam. # slower update at the start will help stability anyway. bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = exp_avg_sq.sqrt().add_(eps) + return -lr * grad / denom @@ -192,16 +187,17 @@ def scaling_step(group, p, state, grad): (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() ) + param_min_rms = group["param_min_rms"] + # scale the step size by param_rms. This is the most important "scaling" part of # ScaledAdam - delta = delta * param_rms + delta *= param_rms.clamp(min=param_min_rms) if step % size_update_period == size_update_period - 1 and step > 0: # This block updates the size of parameter by adding a step ("delta") value in # the direction of either shrinking or growing it. beta2 = group["betas"][1] size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] param_max_rms = group["param_max_rms"] eps = group["eps"] batch_size = p.shape[0] @@ -228,6 +224,10 @@ def scaling_step(group, p, state, grad): # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) + # The following may help prevent instability: don't allow the scale step to be too large in + # either direction. + scale_step.clamp_(min=-0.1, max=0.1) + # and ensure the parameter rms after update never exceeds param_max_rms. # We have to look at the trained model for parameters at or around the # param_max_rms, because sometimes they can indicate a problem with the @@ -256,7 +256,6 @@ def momentum_step(group, p, state, grad): - class ScaledAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update @@ -495,60 +494,18 @@ def step(self, closure=None): state["step"] = 0 cur_step = 0 - grad = (p.grad if clipping_scale == 1.0 else p.grad * clipping_scale) + grad = (p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)) p += momentum_step(group, p.detach(), state, grad) - state["step"] = cur_step + 1 + if p.numel() == p.shape[0]: # scalar parameter + scalar_max = group["scalar_max"] + p.clamp_(min=-scalar_max, max=scalar_max) + state["step"] = cur_step + 1 - #self._step_one_batch(group, p, state, clipping_scale) return loss - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] ) -> float: @@ -726,182 +683,6 @@ def _show_gradient_dominating_parameter( f" orig_rms_sq={(dominant_rms**2).item():.3e}" ) - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad *= clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) class LRScheduler(object): From a204c5ea09dc1c638f8f45a946cc56fa962bd51e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 13 Dec 2024 20:51:02 +0800 Subject: [PATCH 0007/1191] Take optim.py with better debug output. --- egs/librispeech/ASR/zipformer/optim.py | 43 ++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3ffc15d97b..a06fed7544 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -586,7 +586,7 @@ def _get_clipping_scale( ) first_state["num_clipped"] = 0 quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.warn( + logging.warning( f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" ) @@ -601,8 +601,8 @@ def _get_clipping_scale( ans = 0.0 if ans < 1.0: first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( + if ans < 0.5: + logging.warning( f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" ) if self.show_dominant_parameters: @@ -610,6 +610,7 @@ def _get_clipping_scale( self._show_gradient_dominating_parameter( tuples, tot_sumsq, group["scalar_lr_scale"] ) + self._show_param_with_unusual_grad(tuples) if ans == 0.0: for (p, state, param_names) in tuples: @@ -617,6 +618,36 @@ def _get_clipping_scale( return ans + def _show_param_with_unusual_grad( + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + ): + """ + Print information about parameter which has the largest ratio of grad-on-this-batch + divided by normal grad size. + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + largest_ratio = 0.0 + largest_name = "" + for (p, state, batch_param_names) in tuples: + dims = list(range(1, p.ndim)) + grad_ratio = ((p.grad ** 2).mean(dim=dims) / + state["exp_avg_sq"].mean(dim=dims)) + max_grad_ratio, max_index = grad_ratio.to('cpu').max(dim=0) + if max_grad_ratio.item() > largest_ratio: + largest_ratio = max_grad_ratio.item() + largest_name = batch_param_names[max_index.item()] + logging.warning(f"Parameter with most larger-than-usual grad is {largest_name}, with ratio (cur_grad / normal_grad) of " + f"{largest_ratio ** 0.5}") + + + + def _show_gradient_dominating_parameter( self, tuples: List[Tuple[Tensor, dict, List[str]]], @@ -674,7 +705,7 @@ def _show_gradient_dominating_parameter( dominant_rms, dominant_grad, ) = sorted_by_proportion[dominant_param_name] - logging.warn( + logging.warning( f"Parameter dominating tot_sumsq {dominant_param_name}" f" with proportion {dominant_proportion:.2f}," f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" @@ -779,7 +810,7 @@ def _set_lrs(self): def print_lr(self, is_verbose, group, lr): """Display the current learning rate.""" if is_verbose: - logging.warn( + logging.warning( f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f" of group {group} to {lr:.4e}." ) @@ -1110,7 +1141,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From e967862b44024f10f4ecee71e61862a95974f058 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Dec 2024 11:49:16 +0800 Subject: [PATCH 0008/1191] warmup_start=0.1 --- egs/librispeech/ASR/zipformer/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index c074c32ec7..c77960e05d 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1335,7 +1335,8 @@ def run(rank, world_size, args): clipping_scale=2.0, ) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, + warmup_start=0.1) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From bba1d02e3d3aedb4f92422d5ddebb5ea149b00f7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Dec 2024 14:27:34 +0800 Subject: [PATCH 0009/1191] doubled warmup_batches to 1000 --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index c77960e05d..b2366c8076 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1336,7 +1336,7 @@ def run(rank, world_size, args): ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, - warmup_start=0.1) + warmup_start=0.1, warmup_batches=1000) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 6d4c6275a5a6d7b82f275d305897bd36ab0d2078 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 16 Dec 2024 12:03:40 +0800 Subject: [PATCH 0010/1191] Print more modules for larger-than-usual grads --- egs/librispeech/ASR/zipformer/optim.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index a06fed7544..949626b8e6 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -634,16 +634,22 @@ def _show_param_with_unusual_grad( """ largest_ratio = 0.0 largest_name = "" + ratios_names = [ ] for (p, state, batch_param_names) in tuples: dims = list(range(1, p.ndim)) - grad_ratio = ((p.grad ** 2).mean(dim=dims) / - state["exp_avg_sq"].mean(dim=dims)) - max_grad_ratio, max_index = grad_ratio.to('cpu').max(dim=0) - if max_grad_ratio.item() > largest_ratio: - largest_ratio = max_grad_ratio.item() - largest_name = batch_param_names[max_index.item()] - logging.warning(f"Parameter with most larger-than-usual grad is {largest_name}, with ratio (cur_grad / normal_grad) of " - f"{largest_ratio ** 0.5}") + def mean(x): + # workaround for bad interface of torch's "mean" for when dims is the empty list. + if len(dims) > 0: + return x.mean(dim=dims) + else: + return x + grad_ratio = (mean(p.grad ** 2) / state["exp_avg_sq"].mean(dim=dims)).sqrt() + ratios_names += zip(grad_ratio.to('cpu').tolist(), batch_param_names) + + ratios_names = sorted(ratios_names, reverse=True) + ratios_names = ratios_names[:10] + + logging.warning(f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}") From d3784295bfa87e8ce9e8b13664e7dcd4826f760a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Jan 2025 20:29:59 +0800 Subject: [PATCH 0011/1191] Make zipformer more deterministic by removing various dropouts/masks. --- egs/librispeech/ASR/zipformer/train.py | 9 - egs/librispeech/ASR/zipformer/zipformer.py | 262 +-------------------- 2 files changed, 11 insertions(+), 260 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index b2366c8076..fb92c20c16 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -191,14 +191,6 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Positional-encoding embedding dimension", ) - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - parser.add_argument( "--cnn-module-kernel", type=str, @@ -646,7 +638,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: downsampling_factor=_to_int_tuple(params.downsampling_factor), num_encoder_layers=_to_int_tuple(params.num_encoder_layers), encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), query_head_dim=_to_int_tuple(params.query_head_dim), pos_head_dim=_to_int_tuple(params.pos_head_dim), value_head_dim=_to_int_tuple(params.value_head_dim), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0ae01297..79af9ae92b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -65,9 +65,6 @@ class Zipformer2(EncoderInterface): encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per encoder stack. num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of - the encoder stacks for purposes of per-frame dropout (recommend 256 for - now). query_head_dim (int or Tuple[int]): dimension of query and key per attention head: per stack, if a tuple.. pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per @@ -103,7 +100,6 @@ def __init__( downsampling_factor: Tuple[int] = (2, 4), encoder_dim: Union[int, Tuple[int]] = 384, num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, query_head_dim: Union[int, Tuple[int]] = 24, pos_head_dim: Union[int, Tuple[int]] = 4, value_head_dim: Union[int, Tuple[int]] = 12, @@ -136,9 +132,6 @@ def _to_tuple(x): self.output_downsampling_factor = output_downsampling_factor # int self.downsampling_factor = downsampling_factor # tuple self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( - encoder_unmasked_dim - ) # tuple num_encoder_layers = _to_tuple(num_encoder_layers) self.num_encoder_layers = num_encoder_layers self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) @@ -152,9 +145,6 @@ def _to_tuple(x): self.chunk_size = chunk_size self.left_context_frames = left_context_frames - for u, d in zip(encoder_unmasked_dim, encoder_dim): - assert u <= d - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder encoders = [] @@ -180,9 +170,6 @@ def _to_tuple(x): num_encoder_layers[i], pos_dim=pos_dim, dropout=dropout, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), ) if downsampling_factor[i] != 1: @@ -205,68 +192,6 @@ def _to_tuple(x): causal=causal, ) - def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all encoder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoder dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (1, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dim) - if not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dim[0] == _encoder_dims0, ( - self.encoder_dim[0], - _encoder_dims0, - ) - - feature_mask_dropout_prob = 0.125 - - # mask1 shape: (1, batch_size, 1) - mask1 = ( - torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob - ).to(x.dtype) - - # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and( - mask1, - ( - torch.rand(1, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype), - ) - - # dim: (1, batch_size, 2) - mask = torch.cat((mask1, mask2), dim=-1) - - feature_masks = [] - for i in range(num_encoders): - channels = self.encoder_dim[i] - feature_mask = torch.ones( - 1, batch_size, channels, dtype=x.dtype, device=x.device - ) - u1 = self.encoder_unmasked_dim[i] - u2 = u1 + (channels - u1) // 2 - - feature_mask[:, :, u1:u2] *= mask[..., 0:1] - feature_mask[:, :, u2:] *= mask[..., 1:2] - - feature_masks.append(feature_mask) - - return feature_masks - def get_chunk_info(self) -> Tuple[int, int]: """ Returns chunk_size and left_context_chunks. @@ -318,10 +243,6 @@ def forward( of frames in `embeddings` before padding. """ outputs = [] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - feature_masks = [1.0] * len(self.encoder_dim) - else: - feature_masks = self.get_feature_masks(x) chunk_size, left_context_chunks = self.get_chunk_info() @@ -338,7 +259,6 @@ def forward( x = module( x, chunk_size=chunk_size, - feature_mask=feature_masks[i], src_key_padding_mask=( None if src_key_padding_mask is None @@ -575,48 +495,17 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - conv_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - const_attention_rate: FloatLike = ScheduledFloat( - (0.0, 0.25), (4000.0, 0.025), default=0 - ), - ff2_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - ff3_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - bypass_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.5), (4000.0, 0.02), default=0 - ), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim # self.bypass implements layer skipping as well as bypass; see its default values. self.bypass = BypassModule( - embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + embed_dim, ) # bypass_mid is bypass used in the middle of the layer. self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. - self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, @@ -715,31 +604,6 @@ def __init__( max_abs=4.0, ) - def get_sequence_dropout_mask( - self, x: Tensor, dropout_rate: float - ) -> Optional[Tensor]: - if ( - dropout_rate == 0.0 - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - def forward( self, src: Tensor, @@ -754,8 +618,6 @@ def forward( src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). True means masked position. May be None. @@ -767,14 +629,6 @@ def forward( """ src_orig = src - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting() or torch.jit.is_tracing(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = ( - float(self.attention_skip_rate) if self.training else 0.0 - ) - # attn_weights: (num_heads, batch_size, seq_len, seq_len) attn_weights = self.self_attn_weights( src, @@ -785,88 +639,25 @@ def forward( src = src + self.feed_forward1(src) - self_attn_dropout_mask = self.get_sequence_dropout_mask( - src, attention_skip_rate - ) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. - # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to( - selected_attn_weights.dtype - ) - selected_attn_weights = selected_attn_weights * ( - 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) - ) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + src = src + self.balancer_na(self.nonlin_attention(src, attn_weights[0:1])) - src = src + ( - na if self_attn_dropout_mask is None else na * self_attn_dropout_mask - ) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) + src = src + self.self_attn1(src, attn_weights) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.conv_module1( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - conv_skip_rate, + src = src + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask ) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate - ) + src = src + self.balancer_ff2(self.feed_forward2(src)) # bypass in the middle of the layer. src = self.bypass_mid(src_orig, src) - self_attn = self.self_attn2(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.conv_module2( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - conv_skip_rate, - ) + src = src + self.self_attn2(src, attn_weights) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + src = src + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask ) + src = src + self.balancer_ff3(self.feed_forward3(src)) src = self.balancer1(src) src = self.norm(src) @@ -1017,14 +808,10 @@ def __init__( num_layers: int, pos_dim: int, dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, ) -> None: super().__init__() self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.15, length_factor=1.0 + pos_dim, dropout_rate=0.0, length_factor=1.0 ) self.layers = nn.ModuleList( @@ -1032,24 +819,10 @@ def __init__( ) self.num_layers = num_layers - assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat( - (cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0, - ) - cur_begin = cur_end - def forward( self, src: Tensor, chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: @@ -1058,8 +831,6 @@ def forward( Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). True means masked position. May be None. @@ -1071,9 +842,6 @@ def forward( pos_emb = self.encoder_pos(src) output = src - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask - for i, mod in enumerate(self.layers): output = mod( output, @@ -1083,9 +851,6 @@ def forward( src_key_padding_mask=src_key_padding_mask, ) - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask - return output def streaming_forward( @@ -1240,7 +1005,6 @@ def forward( self, src: Tensor, chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: @@ -1248,8 +1012,6 @@ def forward( Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). True means masked position. May be None. @@ -1267,7 +1029,6 @@ def forward( src = self.encoder( src, chunk_size=chunk_size // ds, - feature_mask=feature_mask, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, ) @@ -2432,7 +2193,6 @@ def _test_zipformer_main(causal: bool = False): c = Zipformer2( encoder_dim=(64, 96), - encoder_unmasked_dim=(48, 64), num_heads=(4, 4), causal=causal, chunk_size=(4,) if causal else (-1,), From c97aaad33b2bf1d6b7e98b908e77973683c44c80 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Jan 2025 21:49:58 +0800 Subject: [PATCH 0012/1191] Introduce randomization in zipformer encoder layers. --- egs/librispeech/ASR/zipformer/zipformer.py | 54 +++++++++++++++++++++- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 79af9ae92b..ca9e36eb4d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -466,7 +466,6 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: def _balancer_schedule(min_prob: float): return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - class Zipformer2EncoderLayer(nn.Module): """ Args: @@ -482,7 +481,6 @@ class Zipformer2EncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ - def __init__( self, embed_dim: int, @@ -495,10 +493,12 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, + randomize_scale: FloatLike = ScheduledFloat((0.0, 0.0), (18000.0, 0.0), (40000.0, 2.0)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim + self.randomize_scale = copy.deepcopy(randomize_scale) # self.bypass implements layer skipping as well as bypass; see its default values. self.bypass = BypassModule( embed_dim, @@ -604,6 +604,7 @@ def __init__( max_abs=4.0, ) + def forward( self, src: Tensor, @@ -611,6 +612,53 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + randomize: bool = False, # do the invertibility-encouraging randomization if True. + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + randomize: if true use a form of randomization/dropout that encourages invertibility. + + Returns: + A tensor which has the same shape as src + """ + ans = self.forward_internal(src, pos_emb, chunk_size, attn_mask, src_key_padding_mask) + if not (randomize and self.training): + return ans + scale = float(self.randomize_scale) + if scale == 0.0: + return ans + + (seq_len, batch_size, emb_dim) = src.shape + t = torch.empty(batch_size, 1).uniform_(0.1, 2.0).clamp_(max=1.0) + # t is random from 0.1 to 1, many elements exactly 1. + + xt = src + (ans - src) * t + ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) + x0 = xt - (ans_t - xt) * t + # x0 is a reconstruction of src based on the assumption that there are + # straight non-crossing trajectories. + # If everything is nicely invertible, x0 - src will be zero and "rand" will be + # zero. + rand = torch.randn_like(src) * (x0 - src) + return ans + rand + + + def forward_internal( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """ Pass the input through the encoder layer. @@ -842,6 +890,7 @@ def forward( pos_emb = self.encoder_pos(src) output = src + invertible_layer = random.randint(0, len(self.layers)) for i, mod in enumerate(self.layers): output = mod( output, @@ -849,6 +898,7 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, + randomize=(i == invertible_layer), ) return output From 4300af75faa4252ac1500d43d174b20043ad90ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Jan 2025 21:52:13 +0800 Subject: [PATCH 0013/1191] Introduce penalty on non-invertibility. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ca9e36eb4d..4589d2e0c7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -493,7 +493,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 0.0), (18000.0, 0.0), (40000.0, 2.0)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 0.1), (18000.0, 0.1), (40000.0, 2.0)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -634,8 +634,8 @@ def forward( if not (randomize and self.training): return ans scale = float(self.randomize_scale) - if scale == 0.0: - return ans + #if scale == 0.0: + # return ans (seq_len, batch_size, emb_dim) = src.shape t = torch.empty(batch_size, 1).uniform_(0.1, 2.0).clamp_(max=1.0) From b4099001c998ee1b1ef896453644ebad4ab0cc06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Jan 2025 22:02:19 +0800 Subject: [PATCH 0014/1191] Device related fix --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4589d2e0c7..ebd85b3b1c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -638,7 +638,7 @@ def forward( # return ans (seq_len, batch_size, emb_dim) = src.shape - t = torch.empty(batch_size, 1).uniform_(0.1, 2.0).clamp_(max=1.0) + t = torch.empty(batch_size, 1, device=src.device).uniform_(0.1, 2.0).clamp_(max=1.0) # t is random from 0.1 to 1, many elements exactly 1. xt = src + (ans - src) * t From 0e93358ab1b48408a93b779bab68f9d9950309e7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 10:01:52 +0800 Subject: [PATCH 0015/1191] Change scaling as it varies with t, to get more even amounts of noise. --- egs/librispeech/ASR/zipformer/zipformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ebd85b3b1c..ef01b691a4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -493,7 +493,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 0.1), (18000.0, 0.1), (40000.0, 2.0)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 0.1), (18000.0, 0.1), (40000.0, 1.0)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -644,11 +644,11 @@ def forward( xt = src + (ans - src) * t ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) x0 = xt - (ans_t - xt) * t - # x0 is a reconstruction of src based on the assumption that there are - # straight non-crossing trajectories. - # If everything is nicely invertible, x0 - src will be zero and "rand" will be - # zero. - rand = torch.randn_like(src) * (x0 - src) + # If everything is nicely invertible, the trajectories at t=0 and t will be the same + # so (ans - ans_t) will be zero and "rand" will be zero. + # we divide by "t" to get more even loss-function values. can interpret this as + # a penalty on difference of 2nd derivatives w.r.t. t. + rand = torch.randn_like(src) * (ans - ans_t) / t return ans + rand From f40afcafab007d3c72d326ef9be8f34b23f202b3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 10:16:39 +0800 Subject: [PATCH 0016/1191] Turn off dropout, and oom check --- egs/librispeech/ASR/zipformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index fb92c20c16..c3763a6282 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -645,7 +645,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=_to_int_tuple(params.num_heads), feedforward_dim=_to_int_tuple(params.feedforward_dim), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + dropout=0.0, warmup_batches=4000.0, causal=params.causal, chunk_size=_to_int_tuple(params.chunk_size), @@ -1420,7 +1420,7 @@ def remove_short_and_long_utt(c: Cut): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if not params.print_diagnostics and False: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, From adabb8429c152c46b2c18267f87b896d96729f09 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 10:34:37 +0800 Subject: [PATCH 0017/1191] Bug fix to random term --- egs/librispeech/ASR/zipformer/zipformer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ef01b691a4..fa888ea5c3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -641,14 +641,11 @@ def forward( t = torch.empty(batch_size, 1, device=src.device).uniform_(0.1, 2.0).clamp_(max=1.0) # t is random from 0.1 to 1, many elements exactly 1. - xt = src + (ans - src) * t + v0 = ans - src + xt = src + v0 * t ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) - x0 = xt - (ans_t - xt) * t - # If everything is nicely invertible, the trajectories at t=0 and t will be the same - # so (ans - ans_t) will be zero and "rand" will be zero. - # we divide by "t" to get more even loss-function values. can interpret this as - # a penalty on difference of 2nd derivatives w.r.t. t. - rand = torch.randn_like(src) * (ans - ans_t) / t + vt = ans_t - xt + rand = torch.randn_like(src) * (vt - v0) / t return ans + rand From cf30780cf865c236ee2add89f35b037fb15d0dd5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 15:24:31 +0800 Subject: [PATCH 0018/1191] Implement invertible upsampling and downsampling. There is now no bypass at zipformer level. --- egs/librispeech/ASR/zipformer/scaling.py | 33 ++ egs/librispeech/ASR/zipformer/zipformer.py | 358 +++++++-------------- 2 files changed, 151 insertions(+), 240 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d345c29316..d8d6764161 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -565,6 +565,39 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans +def OrthogonalLinearDownsampling(num_channels: int): + # returns a parameterized nn.Linear that stays orthogonal, with a special initialization + # that is suitable to use when downsampling; we reshape then multiply by this matrix. + assert num_channels % 2 == 0 + ans = nn.Linear(num_channels, num_channels, bias=False) + inv_sqrt2 = 2 ** -0.5 + N = num_channels // 2 + eye = inv_sqrt2 * torch.eye(N) + # four blocks: (1/sqrt(2)) (1, 1; -1, 1) + with torch.no_grad(): + ans.weight[:N, :N] = eye + ans.weight[:N, N:] = eye + ans.weight[N:, :N] = -eye + ans.weight[N:, N:] = eye + return torch.nn.utils.parametrizations.orthogonal(ans) + +def OrthogonalLinearUpsampling(num_channels: int): + # returns a parameterized nn.Linear that stays orthogonal, with a special initialization + # that is suitable to use when downsampling; we multiply by this matrix then reshape. + assert num_channels % 2 == 0 + ans = nn.Linear(num_channels, num_channels, bias=False) + inv_sqrt2 = 2 ** -0.5 + N = num_channels // 2 + eye = inv_sqrt2 * torch.eye(N) + # four blocks: (1/sqrt(2)) (1, -1; 1, 1) + with torch.no_grad(): + ans.weight[:N, :N] = eye + ans.weight[:N, N:] = -eye + ans.weight[N:, :N] = eye + ans.weight[N:, N:] = eye + return torch.nn.utils.parametrizations.orthogonal(ans) + + class ChunkCausalDepthwiseConv1d(torch.nn.Module): """ diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fa888ea5c3..bb7bfe6e6c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -25,13 +25,11 @@ import torch from encoder_interface import EncoderInterface -from scaling import ( +from scaling2 import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. -) -from scaling import ( + OrthogonalLinearUpsampling, + OrthogonalLinearDownsampling, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( ActivationDropoutAndLinear, Balancer, BiasNorm, @@ -130,6 +128,7 @@ def _to_tuple(x): return x self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple num_encoder_layers = _to_tuple(num_encoder_layers) @@ -145,11 +144,29 @@ def _to_tuple(x): self.chunk_size = chunk_size self.left_context_frames = left_context_frames - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + # each one will be Zipformer2Encoder or InvertibleDownsample or InvertibleUpsample encoders = [] num_encoders = len(downsampling_factor) + cur_downsample = 1 + input_dim = encoder_dim[0] + + # caution: some changes we made for this break the streaming, later we'll try to fix this. + encoders_downsampling_factors = [ ] + + def set_downsample_factor(cur_downsample, ds): + while cur_downsample < ds: + # need to downsample + encoders.append(InvertibleDownsample(input_dim * cur_downsample)) + cur_downsample *= 2 + while cur_downsample > ds: + encoders.append(InvertibleUpsample(input_dim * cur_downsample)) + cur_downsample //= 2 + return cur_downsample + for i in range(num_encoders): + cur_downsample = set_downsample_factor(cur_downsample, downsampling_factor[i]) + encoder_layer = Zipformer2EncoderLayer( embed_dim=encoder_dim[i], pos_dim=pos_dim, @@ -171,26 +188,14 @@ def _to_tuple(x): pos_dim=pos_dim, dropout=dropout, ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim[i], - downsample=downsampling_factor[i], - dropout=dropout, - causal=causal, - ) - + encoder.encoder_index = i # <-- will be used in streaming_forward encoders.append(encoder) + cur_downsample = set_downsample_factor(cur_downsample, output_downsampling_factor) + self.encoders = nn.ModuleList(encoders) - self.downsample_output = SimpleDownsample( - max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout, - causal=causal, - ) + def get_chunk_info(self) -> Tuple[int, int]: """ @@ -242,7 +247,6 @@ def forward( - lengths, a tensor of shape (batch_size,) containing the number of frames in `embeddings` before padding. """ - outputs = [] chunk_size, left_context_chunks = self.get_chunk_info() @@ -252,29 +256,28 @@ def forward( else: attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) + for module in self.encoders: + if isinstance(module, Zipformer2Encoder): + i = module.encoder_index # was set in this class's __init__ function. + ds = self.downsampling_factor[i] + x = module( + x, + chunk_size=chunk_size, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=(None + if attn_mask is None + else attn_mask[::ds, ::ds] + ), + ) + else: + x = module(x) + + x = x[..., :self.encoder_dim[-1]] - x = module( - x, - chunk_size=chunk_size, - src_key_padding_mask=( - None - if src_key_padding_mask is None - else src_key_padding_mask[..., ::ds] - ), - attn_mask=attn_mask, - ) - outputs.append(x) - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. assert self.output_downsampling_factor == 2, self.output_downsampling_factor if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 @@ -328,21 +331,6 @@ def _get_attn_mask( logging.info(f"attn_mask = {attn_mask}") return attn_mask - def _get_full_dim_output(self, outputs: List[Tensor]): - num_encoders = len(self.encoder_dim) - assert len(outputs) == num_encoders - output_dim = max(self.encoder_dim) - output_pieces = [outputs[-1]] - cur_dim = self.encoder_dim[-1] - for i in range(num_encoders - 2, -1, -1): - d = self.encoder_dim[i] - if d > cur_dim: - this_output = outputs[i] - output_pieces.append(this_output[..., cur_dim:d]) - cur_dim = d - assert cur_dim == output_dim - return torch.cat(output_pieces, dim=-1) - def streaming_forward( self, x: Tensor, @@ -370,31 +358,28 @@ def streaming_forward( of frames in `embeddings` before padding. - updated states """ - outputs = [] new_states = [] layer_offset = 0 - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) + for module in enumerate(self.encoders): + if not isinstance(module, Zipformer2Encoder): + x = module(x) + else: + i = module.encoder_index # was set in this class's __init__ function. + num_layers = module.num_layers + ds = self.downsampling_factor[i] + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + new_states += new_layer_states + + x = x[..., :self.encoder_dim[-1]] - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - outputs.append(x) - new_states += new_layer_states - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) # class Downsample has this rounding behavior.. assert self.output_downsampling_factor == 2 if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -874,7 +859,9 @@ def forward( r"""Pass the input through the encoder layers in turn. Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), + but embed_dim is allowed to exceed the modules' embed_dim; we will bypass + any extra dimensions. chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). @@ -885,12 +872,16 @@ def forward( Returns: a Tensor with the same shape as src. """ pos_emb = self.encoder_pos(src) - output = src + + num_channels = src.shape[-1] + layer_dim = self.layers[0].embed_dim + if num_channels > layer_dim: + src, bypass = src[..., :layer_dim], src[..., layer_dim:] invertible_layer = random.randint(0, len(self.layers)) for i, mod in enumerate(self.layers): - output = mod( - output, + src = mod( + src, pos_emb, chunk_size=chunk_size, attn_mask=attn_mask, @@ -898,7 +889,10 @@ def forward( randomize=(i == invertible_layer), ) - return output + if num_channels > layer_dim: + src = torch.cat((src, bypass), dim=-1) + + return src def streaming_forward( self, @@ -910,7 +904,7 @@ def streaming_forward( r"""Pass the input through the encoder layers in turn. Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). left_context_len: Number of left context frames. @@ -923,7 +917,10 @@ def streaming_forward( - updated states """ pos_emb = self.encoder_pos(src, left_context_len) - output = src + num_channels = src.shape[-1] + layer_dim = self.layers[0].embed_dim + if num_channels > layer_dim: + src, bypass = src[..., :layer_dim], src[..., layer_dim:] new_states = [] for i, mod in enumerate(self.layers): @@ -936,7 +933,7 @@ def streaming_forward( cached_conv2, ) = states[i * 6 : (i + 1) * 6] ( - output, + src, new_cached_key, new_cached_nonlin_attn, new_cached_val1, @@ -944,7 +941,7 @@ def streaming_forward( new_cached_conv1, new_cached_conv2, ) = mod.streaming_forward( - output, + src, pos_emb, cached_key=cached_key, cached_nonlin_attn=cached_nonlin_attn, @@ -964,7 +961,10 @@ def streaming_forward( new_cached_conv2, ] - return output, new_states + if num_channels > layer_dim: + src = torch.cat((src, bypass), dim=-1) + + return src, new_states class BypassModule(nn.Module): @@ -1025,120 +1025,18 @@ def forward(self, src_orig: Tensor, src: Tensor): return src_orig + (src - src_orig) * bypass_scale -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, - encoder: nn.Module, - dim: int, - downsample: int, - dropout: FloatLike, - causal: bool, - ): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, downsample, dropout, causal) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(dim, downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if attn_mask is not None: - attn_mask = attn_mask[::ds, ::ds] - - src = self.encoder( - src, - chunk_size=chunk_size // ds, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Downsample, go through encoder, upsample, in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); - True means masked position. May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - src_orig = src - src = self.downsample(src) - - src, new_states = self.encoder.streaming_forward( - src, - states=states, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src), new_states - -class SimpleDownsample(torch.nn.Module): +class InvertibleDownsample(torch.nn.Module): """ - Does downsampling with attention, by weighted sum, and a projection.. + Does downsampling in an invertible way, by a factor of two. """ - def __init__( - self, channels: int, downsample: int, dropout: FloatLike, causal: bool + self, channels: int, causal: bool = False, ): - super(SimpleDownsample, self).__init__() + super().__init__() + self.proj = OrthogonalLinearDownsampling(channels * 2) self.causal = causal - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - self.dropout = copy.deepcopy(dropout) - - self.downsample = downsample def forward(self, src: Tensor) -> Tensor: """ @@ -1147,57 +1045,37 @@ def forward(self, src: Tensor) -> Tensor: ( (seq_len+downsample-1)//downsample, batch_size, channels) """ (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - - if self.causal and torch.jit.is_tracing(): - assert ( - pad == 0 - ), f"pad should be zero for exporting streaming models. Given {pad}" - - # If we are exporting a streaming model, then we skip the if statement - if not self.causal or not torch.jit.is_tracing(): - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - - assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans + if seq_len % 2 == 1: + if torch.jit.is_tracing(): + assert ( + not self.causal + ), f"pad should be zero for exporting streaming models. Given {pad}" + src = torch.cat((src, src[-1:]), dim=0) + + src = src.permute(1, 0, 2).reshape(batch_size, seq_len // 2, in_channels * 2) + src = self.proj(src) + src = src.permute(1, 0, 2) # (seq_len // 2, batch_size, in_channels * 2) + return src -class SimpleUpsample(torch.nn.Module): +class InvertibleUpsample(torch.nn.Module): """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. + A very simple form of upsampling that is the inverse of InvertibleDownsampling. """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample + def __init__(self, channels: int): + super().__init__() + self.proj = OrthogonalLinearUpsampling(channels) def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) + ( (seq_len*2), batch_size, num_channels // 2) """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) + src = self.proj(src) + (seq_len, batch_size, in_channels) = src.shape + src = src.permute(1, 0, 2).reshape(batch_size, seq_len * 2, in_channels // 2) + src = src.permute(1, 0, 2) # (seq_len * 2, batch_size, in_channels // 2) return src @@ -1216,7 +1094,7 @@ class CompactRelPositionalEncoding(torch.nn.Module): making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic function to compress large offsets to a smaller range before applying atan(). Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) Args: From ed65a180822c786a8f1d5be4bff080334fb9a8be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 15:27:12 +0800 Subject: [PATCH 0019/1191] Make randomize_scale increase faster. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bb7bfe6e6c..e3685f69bd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -478,7 +478,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 0.1), (18000.0, 0.1), (40000.0, 1.0)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 0.1), (5000.0, 0.1), (20000.0, 1.0)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim From d35e73ddfeccf5e24f3abf7dad4fa66eda64af7f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 15:33:08 +0800 Subject: [PATCH 0020/1191] Bug fix regarding odd seq_len in downsample --- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e3685f69bd..d80a4b06a4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1052,6 +1052,7 @@ def forward(self, src: Tensor) -> Tensor: not self.causal ), f"pad should be zero for exporting streaming models. Given {pad}" src = torch.cat((src, src[-1:]), dim=0) + seq_len += 1 src = src.permute(1, 0, 2).reshape(batch_size, seq_len // 2, in_channels * 2) src = self.proj(src) @@ -2124,7 +2125,7 @@ def _test_zipformer_main(causal: bool = False): left_context_frames=(64,), ) batch_size = 5 - seq_len = 20 + seq_len = 21 # Just make sure the forward pass runs. f = c( torch.randn(seq_len, batch_size, 64), From 5b0bce46a16bc47da30222f1c028be58073c4dee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 16:13:40 +0800 Subject: [PATCH 0021/1191] Various bug fixes; change encoder dimension. --- egs/librispeech/ASR/zipformer/train.py | 3 ++- egs/librispeech/ASR/zipformer/zipformer.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index c3763a6282..d967bf1714 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -624,9 +624,10 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: # In the normal configuration, we will downsample once more at the end # by a factor of 2, and most of the encoder stacks will run at a lower # sampling rate. + output_downsampling_factor = 2 encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], + out_channels=max(_to_int_tuple(params.encoder_dim)) // output_downsampling_factor, dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), ) return encoder_embed diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d80a4b06a4..83bad639e2 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -149,7 +149,7 @@ def _to_tuple(x): num_encoders = len(downsampling_factor) cur_downsample = 1 - input_dim = encoder_dim[0] + input_dim = max(encoder_dim) // output_downsampling_factor # caution: some changes we made for this break the streaming, later we'll try to fix this. encoders_downsampling_factors = [ ] @@ -256,10 +256,17 @@ def forward( else: attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + orig_seq_len = x.shape[0] + + def truncate(x, downsampling_factor): + max_len = (orig_seq_len + downsampling_factor - 1) // downsampling_factor + return x[:max_len] if x.shape[0] > max_len else x + for module in self.encoders: if isinstance(module, Zipformer2Encoder): i = module.encoder_index # was set in this class's __init__ function. ds = self.downsampling_factor[i] + x = truncate(x, ds) x = module( x, chunk_size=chunk_size, @@ -276,7 +283,7 @@ def forward( else: x = module(x) - x = x[..., :self.encoder_dim[-1]] + x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. assert self.output_downsampling_factor == 2, self.output_downsampling_factor if torch.jit.is_scripting() or torch.jit.is_tracing(): From 0de042ee05ce4a4ea67a8ba2f965952d1cf8c6ab Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Jan 2025 17:35:36 +0800 Subject: [PATCH 0022/1191] Remove dropout of pos_scores; have schedule on self_attn score limit --- egs/librispeech/ASR/zipformer/train.py | 4 ++-- egs/librispeech/ASR/zipformer/zipformer.py | 16 +++++----------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index d967bf1714..390367fc32 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1157,13 +1157,13 @@ def save_bad_model(suffix: str = ""): rank=rank, ) - if batch_idx % 100 == 0 and params.use_autocast: + if batch_idx % 25 == 0 and params.use_autocast: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + if cur_grad_scale < 2.0 or (cur_grad_scale < 8.0 and batch_idx % 100 == 0) or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 83bad639e2..21373597ee 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1233,7 +1233,6 @@ def __init__( query_head_dim: int, pos_head_dim: int, dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), ) -> None: super().__init__() self.embed_dim = embed_dim @@ -1241,9 +1240,10 @@ def __init__( self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) self.name = None # will be overwritten in training code; for diagnostics. + self.attn_score_limit = copy.deepcopy(ScheduledFloat((0.0, 5.0), (40000.0, 20.0))) + key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads @@ -1345,14 +1345,8 @@ def forward( attn_scores = torch.matmul(q, k) - use_pos_scores = False - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: + if True: + # position scores. pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( @@ -1405,7 +1399,7 @@ def forward( # values rather than a regularization method that should be active # under normal circumstances. attn_scores = penalize_abs_values_gt( - attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + attn_scores, limit=float(self.attn_score_limit), penalty=1.0e-04, name=self.name ) assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) From fb48cfdb8871662f25b3b540a124557ae492a343 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Jan 2025 13:36:22 +0800 Subject: [PATCH 0023/1191] Do two forward steps, no backward step, and amplify at the end with higher schedule. --- egs/librispeech/ASR/zipformer/zipformer.py | 33 +++++++++++++--------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 21373597ee..773a9e692f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 0.1), (5000.0, 0.1), (20000.0, 1.0)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 4.0)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -622,22 +622,25 @@ def forward( Returns: A tensor which has the same shape as src """ - ans = self.forward_internal(src, pos_emb, chunk_size, attn_mask, src_key_padding_mask) + ans = self.forward_internal(src, pos_emb, chunk_size, + attn_mask, src_key_padding_mask) if not (randomize and self.training): return ans - scale = float(self.randomize_scale) - #if scale == 0.0: - # return ans - (seq_len, batch_size, emb_dim) = src.shape - t = torch.empty(batch_size, 1, device=src.device).uniform_(0.1, 2.0).clamp_(max=1.0) - # t is random from 0.1 to 1, many elements exactly 1. + # we view the input 'src' as x0 and the answer 'ans' as x1, like in a flow-matching + # situation, and we compute an alternative version of x1 (called "x1" in the code) + # that is computed as two steps. We then amplify the difference between "ans" and + # that alternative version of x1, and multiply it by random noise. - v0 = ans - src - xt = src + v0 * t + (seq_len, batch_size, emb_dim) = src.shape + t = torch.empty(batch_size, 1, device=src.device).uniform_(0.1, 0.9) + xt = src + (ans - src) * t ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) - vt = ans_t - xt - rand = torch.randn_like(src) * (vt - v0) / t + x1 = xt + (ans_t - xt) * (1. - t) + scale = float(self.randomize_scale) + diff = x1 - ans # this is the difference between a 1-step and a 2-step version of x_1. + # we want 'diff' to be zero. + rand = torch.empty_like(src).uniform_(-scale, scale) * diff return ans + rand @@ -2125,17 +2128,19 @@ def _test_zipformer_main(causal: bool = False): chunk_size=(4,) if causal else (-1,), left_context_frames=(64,), ) + input_dim = 96 // 2 # this makes little sense, it relates to how the code used to work. + batch_size = 5 seq_len = 21 # Just make sure the forward pass runs. f = c( - torch.randn(seq_len, batch_size, 64), + torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) f[0].sum().backward() c.eval() f = c( - torch.randn(seq_len, batch_size, 64), + torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) f # to remove flake8 warnings From 75bd8aeb7237be2bb91787d21df52dd1aaf7bc7f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Jan 2025 19:33:34 +0800 Subject: [PATCH 0024/1191] Divide random differene by t-t**2 (should equalize differences), and use randn not uniform(-1..1) --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 773a9e692f..d9d4b037c1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -637,10 +637,10 @@ def forward( xt = src + (ans - src) * t ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) x1 = xt + (ans_t - xt) * (1. - t) - scale = float(self.randomize_scale) + scale = float(self.randomize_scale) / (t - t**2) diff = x1 - ans # this is the difference between a 1-step and a 2-step version of x_1. # we want 'diff' to be zero. - rand = torch.empty_like(src).uniform_(-scale, scale) * diff + rand = torch.randn_like(src) * scale * diff return ans + rand From 0e65f32a0d1ac711f7e700741e0b1b2f6051ab0b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Jan 2025 23:59:50 +0800 Subject: [PATCH 0025/1191] Reduce final randomize_scale from 4 to 3 and divide warmup period by 4. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d9d4b037c1..07e85ac357 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 4.0)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (5000.0, 3.0)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim From 19642fdc8473e34ed7086a06f9a88f4369ef86dc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Jan 2025 12:37:27 +0800 Subject: [PATCH 0026/1191] More efficient OrthogonalLinear; lower scale on invertible loss. --- egs/librispeech/ASR/zipformer/scaling.py | 31 ++++++++++++++++++---- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d8d6764161..6d27fba719 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -565,11 +565,33 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans +class OrthogonalLinear(nn.Linear): + def __init__(num_channels: int, penalty_scale: FloatLike = 100.0): + # caution: the actual scale of the penalty will be affected by the grad_scale in fp16. + # the "effective scale" will be penalty_scale / grad_scale. + # we'll see whether this matters much in practice. + super().__init__(num_channels, num_channels, bias=False) + self.penalty_scale = penalty_scale + + def forward(self, x: Tensor): + ans = nn.functional.linear(x, self.weight, self.bias) + if self.training and random.random() < 0.5: + weight = self.weight + if weight.shape[0] > weight.shape[1]: + weight = weight.t() + prod = torch.matmul(self.weight, self.weight.t()) + err = prod - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) + eps = 1.0e-10 + penalty = float(self.penalty_scale) * (err ** 2).sum() + ans = with_loss(ans, penalty) + return ans + + def OrthogonalLinearDownsampling(num_channels: int): # returns a parameterized nn.Linear that stays orthogonal, with a special initialization # that is suitable to use when downsampling; we reshape then multiply by this matrix. assert num_channels % 2 == 0 - ans = nn.Linear(num_channels, num_channels, bias=False) + ans = OrthogonalLinear(num_channels) inv_sqrt2 = 2 ** -0.5 N = num_channels // 2 eye = inv_sqrt2 * torch.eye(N) @@ -579,13 +601,13 @@ def OrthogonalLinearDownsampling(num_channels: int): ans.weight[:N, N:] = eye ans.weight[N:, :N] = -eye ans.weight[N:, N:] = eye - return torch.nn.utils.parametrizations.orthogonal(ans) + return ans def OrthogonalLinearUpsampling(num_channels: int): # returns a parameterized nn.Linear that stays orthogonal, with a special initialization # that is suitable to use when downsampling; we multiply by this matrix then reshape. assert num_channels % 2 == 0 - ans = nn.Linear(num_channels, num_channels, bias=False) + ans = OrthogonalLinear(num_channels) inv_sqrt2 = 2 ** -0.5 N = num_channels // 2 eye = inv_sqrt2 * torch.eye(N) @@ -595,8 +617,7 @@ def OrthogonalLinearUpsampling(num_channels: int): ans.weight[:N, N:] = -eye ans.weight[N:, :N] = eye ans.weight[N:, N:] = eye - return torch.nn.utils.parametrizations.orthogonal(ans) - + return ans class ChunkCausalDepthwiseConv1d(torch.nn.Module): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 07e85ac357..126c701d6d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (5000.0, 3.0)), + randomize_scale: FloatLike = 1.0, ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -596,7 +596,6 @@ def __init__( max_abs=4.0, ) - def forward( self, src: Tensor, From 68622d64c13a288f3ef9860e6cedaa9420ef4f04 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Jan 2025 12:44:03 +0800 Subject: [PATCH 0027/1191] Increase randomize_scale from 1.0 to 1.5. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 126c701d6d..6a5c7ab752 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = 1.0, + randomize_scale: FloatLike = 1.5, ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim From f2e4da98daad8a436126547204b41c51dc1129d9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Jan 2025 13:35:28 +0800 Subject: [PATCH 0028/1191] Make it print penalty --- egs/librispeech/ASR/zipformer/scaling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6d27fba719..00bd21a4bb 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -572,6 +572,7 @@ def __init__(num_channels: int, penalty_scale: FloatLike = 100.0): # we'll see whether this matters much in practice. super().__init__(num_channels, num_channels, bias=False) self.penalty_scale = penalty_scale + self.name = None # will be set from training loop. for printing penalty. def forward(self, x: Tensor): ans = nn.functional.linear(x, self.weight, self.bias) @@ -583,7 +584,7 @@ def forward(self, x: Tensor): err = prod - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) eps = 1.0e-10 penalty = float(self.penalty_scale) * (err ** 2).sum() - ans = with_loss(ans, penalty) + ans = with_loss(ans, penalty, self.name) return ans From cfd892caf87fc238ecbb063a054c623c627def2a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Jan 2025 13:37:01 +0800 Subject: [PATCH 0029/1191] decrease randomize_scale. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6a5c7ab752..86ac14a6c9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = 1.5, + randomize_scale: FloatLike = 0.66, ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim From c03567cdbe2d7fab777d28ff373f06f0f757dfb8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Jan 2025 15:17:35 +0800 Subject: [PATCH 0030/1191] fix that will make no difference, OrthogonalLinear --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 00bd21a4bb..a39eb5b5a6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -580,7 +580,7 @@ def forward(self, x: Tensor): weight = self.weight if weight.shape[0] > weight.shape[1]: weight = weight.t() - prod = torch.matmul(self.weight, self.weight.t()) + prod = torch.matmul(weight, weight.t()) err = prod - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) eps = 1.0e-10 penalty = float(self.penalty_scale) * (err ** 2).sum() From 2a7b62ceda5020d187508b40ed17960c7b424909 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Jan 2025 15:57:10 +0800 Subject: [PATCH 0031/1191] More debug; change how noise interacts with diffs and put some global element. --- egs/librispeech/ASR/zipformer/zipformer.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 86ac14a6c9..a9adf47741 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,10 +485,11 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = 0.66, + randomize_scale: FloatLike = 1.0, ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim + self.name = None # will be set from training loop self.randomize_scale = copy.deepcopy(randomize_scale) # self.bypass implements layer skipping as well as bypass; see its default values. @@ -636,13 +637,24 @@ def forward( xt = src + (ans - src) * t ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) x1 = xt + (ans_t - xt) * (1. - t) - scale = float(self.randomize_scale) / (t - t**2) - diff = x1 - ans # this is the difference between a 1-step and a 2-step version of x_1. - # we want 'diff' to be zero. - rand = torch.randn_like(src) * scale * diff + diff = (x1 - ans) / (t - t**2) + + diff_sqscale = (diff ** 2).mean(dim=2, keepdim=True) + G = 0.2 # scale on the global-mean part of the random-noise scale. + scale = float(self.randomize_scale) + diff_scale = (scale * G) * (diff_sqscale ** 2).mean().sqrt() + (scale * (1. - G)) * diff_sqscale.sqrt() + rand = torch.randn_like(src) * diff_scale + if random.random() < 0.01 or __name__ == '__main__': + # logging output + diff_scale = (diff ** 2).mean(dim=(0, 2)).sqrt() + t_flat = t.flatten() + values, indexes = t_flat.sort() + logging.info(f"name={self.name}: diff_scale={diff_scale[indexes]}, t={values}; global-scale={(diff_sqscale**2).mean().sqrt()}") + return ans + rand + def forward_internal( self, src: Tensor, From 63596cd3f39d99d58538b3710e91ef9435209492 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Jan 2025 18:05:28 +0800 Subject: [PATCH 0032/1191] Bug fix. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a9adf47741..b8a824acdf 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -642,14 +642,14 @@ def forward( diff_sqscale = (diff ** 2).mean(dim=2, keepdim=True) G = 0.2 # scale on the global-mean part of the random-noise scale. scale = float(self.randomize_scale) - diff_scale = (scale * G) * (diff_sqscale ** 2).mean().sqrt() + (scale * (1. - G)) * diff_sqscale.sqrt() + diff_scale = (scale * G) * diff_sqscale.mean().sqrt() + (scale * (1. - G)) * diff_sqscale.sqrt() rand = torch.randn_like(src) * diff_scale if random.random() < 0.01 or __name__ == '__main__': # logging output diff_scale = (diff ** 2).mean(dim=(0, 2)).sqrt() t_flat = t.flatten() values, indexes = t_flat.sort() - logging.info(f"name={self.name}: diff_scale={diff_scale[indexes]}, t={values}; global-scale={(diff_sqscale**2).mean().sqrt()}") + logging.info(f"name={self.name}: diff_scale={diff_scale[indexes]}, t={values}; global-scale={diff_sqscale.mean().sqrt()}") return ans + rand From 38dcd9cf14bd234cf5644210b8a097f73786235a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jan 2025 13:09:15 +0800 Subject: [PATCH 0033/1191] Fix to scaling.py --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a39eb5b5a6..adf635c95c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -566,7 +566,7 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans class OrthogonalLinear(nn.Linear): - def __init__(num_channels: int, penalty_scale: FloatLike = 100.0): + def __init__(self, num_channels: int, penalty_scale: FloatLike = 100.0): # caution: the actual scale of the penalty will be affected by the grad_scale in fp16. # the "effective scale" will be penalty_scale / grad_scale. # we'll see whether this matters much in practice. @@ -574,6 +574,10 @@ def __init__(num_channels: int, penalty_scale: FloatLike = 100.0): self.penalty_scale = penalty_scale self.name = None # will be set from training loop. for printing penalty. + # by default, initialize to the identity. + with torch.no_grad(): + self.weight[:] = torch.eye(num_channels) + def forward(self, x: Tensor): ans = nn.functional.linear(x, self.weight, self.bias) if self.training and random.random() < 0.5: From 53b5ffbaf3a40f9c60f46aa3aa1820741ba7f116 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jan 2025 16:19:28 +0800 Subject: [PATCH 0034/1191] take scaling.py from branch deterministic_invertible1frontend: randomized way of enforcing invertibility loss. --- egs/librispeech/ASR/zipformer/scaling.py | 39 +++++++++++++++--------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index adf635c95c..7e23a5e0be 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -17,6 +17,7 @@ import logging import math +import copy import random from typing import Optional, Tuple, Union @@ -566,12 +567,9 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans class OrthogonalLinear(nn.Linear): - def __init__(self, num_channels: int, penalty_scale: FloatLike = 100.0): - # caution: the actual scale of the penalty will be affected by the grad_scale in fp16. - # the "effective scale" will be penalty_scale / grad_scale. - # we'll see whether this matters much in practice. + def __init__(self, num_channels: int, penalty_scale: FloatLike = 2.0): super().__init__(num_channels, num_channels, bias=False) - self.penalty_scale = penalty_scale + self.penalty_scale = copy.deepcopy(penalty_scale) self.name = None # will be set from training loop. for printing penalty. # by default, initialize to the identity. @@ -580,15 +578,22 @@ def __init__(self, num_channels: int, penalty_scale: FloatLike = 100.0): def forward(self, x: Tensor): ans = nn.functional.linear(x, self.weight, self.bias) - if self.training and random.random() < 0.5: - weight = self.weight - if weight.shape[0] > weight.shape[1]: - weight = weight.t() - prod = torch.matmul(weight, weight.t()) - err = prod - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) - eps = 1.0e-10 - penalty = float(self.penalty_scale) * (err ** 2).sum() - ans = with_loss(ans, penalty, self.name) + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return ans + penalty_scale = float(self.penalty_scale) + if penalty_scale == 0.0: + return ans + ans_scale = (ans ** 2).mean() + weight = self.weight + if weight.shape[0] > weight.shape[1]: + weight = weight.t() + prod = torch.matmul(weight, weight.t()) + err = prod - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) + err = (err ** 2).mean() + noise_scale = penalty_scale * ans_scale * err + if random.random() < 0.001 or __name__ == '__main__': + logging.info(f"noise_scale = {noise_scale.item()} = {penalty_scale}*{ans_scale.item()}*{err.item()}") + ans = ans + noise_scale * torch.randn_like(ans) return ans @@ -1952,6 +1957,11 @@ def isclose(a, b): # storage of it. assert isclose(x1.grad, x2.grad) +def _test_orthogonal_linear(): + for t in (OrthogonalLinear, OrthogonalLinearUpsampling, OrthogonalLinearDownsampling): + m = t(128) + m(torch.randn(30, 2, 128)) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) @@ -1966,3 +1976,4 @@ def isclose(a, b): _test_swooshr_deriv() _test_swooshl_deriv() _test_activation_dropout_and_linear() + _test_orthogonal_linear() From 35830256670c4c8feae60c9543eb3f456771d75f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jan 2025 17:06:00 +0800 Subject: [PATCH 0035/1191] Fix scaling2->scaling --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b8a824acdf..deb36f1842 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -25,7 +25,7 @@ import torch from encoder_interface import EncoderInterface -from scaling2 import ( +from scaling import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. OrthogonalLinearUpsampling, OrthogonalLinearDownsampling, From 60e69ef464b213a5d99573717c882a8ceeb55adf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jan 2025 17:24:10 +0800 Subject: [PATCH 0036/1191] Add name to debug output --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7e23a5e0be..a8e9590890 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -592,7 +592,7 @@ def forward(self, x: Tensor): err = (err ** 2).mean() noise_scale = penalty_scale * ans_scale * err if random.random() < 0.001 or __name__ == '__main__': - logging.info(f"noise_scale = {noise_scale.item()} = {penalty_scale}*{ans_scale.item()}*{err.item()}") + logging.info(f"{self.name}: noise_scale = {noise_scale.item()} = {penalty_scale}*{ans_scale.item()}*{err.item()}") ans = ans + noise_scale * torch.randn_like(ans) return ans From f214fb0f607be10b6a34e3f650dace474819b08b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jan 2025 17:33:43 +0800 Subject: [PATCH 0037/1191] Allow the orthogonal matrices to be scaled by a scalar --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a8e9590890..41e708b213 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -587,8 +587,8 @@ def forward(self, x: Tensor): weight = self.weight if weight.shape[0] > weight.shape[1]: weight = weight.t() - prod = torch.matmul(weight, weight.t()) - err = prod - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) + prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. + err = prod / prod.diag().mean() - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) err = (err ** 2).mean() noise_scale = penalty_scale * ans_scale * err if random.random() < 0.001 or __name__ == '__main__': From 9e415c0edb1fa3f2f01001ea30460e346293ea86 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jan 2025 19:29:49 +0800 Subject: [PATCH 0038/1191] change scaling.py to use WithLoss for orthogonality, now allowing scalar scale. --- egs/librispeech/ASR/zipformer/scaling.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 41e708b213..9604cefd6b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -567,7 +567,7 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans class OrthogonalLinear(nn.Linear): - def __init__(self, num_channels: int, penalty_scale: FloatLike = 2.0): + def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): super().__init__(num_channels, num_channels, bias=False) self.penalty_scale = copy.deepcopy(penalty_scale) self.name = None # will be set from training loop. for printing penalty. @@ -583,20 +583,17 @@ def forward(self, x: Tensor): penalty_scale = float(self.penalty_scale) if penalty_scale == 0.0: return ans - ans_scale = (ans ** 2).mean() weight = self.weight if weight.shape[0] > weight.shape[1]: weight = weight.t() prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. err = prod / prod.diag().mean() - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) - err = (err ** 2).mean() - noise_scale = penalty_scale * ans_scale * err + err = (err ** 2).sum() + ans = with_loss(ans, err * float(self.penalty_scale), self.name) if random.random() < 0.001 or __name__ == '__main__': - logging.info(f"{self.name}: noise_scale = {noise_scale.item()} = {penalty_scale}*{ans_scale.item()}*{err.item()}") - ans = ans + noise_scale * torch.randn_like(ans) + logging.info(f"{self.name}: dim={weight.shape}, avg_err = {err*float(self.penalty_scale)}={err}*{float(self.penalty_scale)}") return ans - def OrthogonalLinearDownsampling(num_channels: int): # returns a parameterized nn.Linear that stays orthogonal, with a special initialization # that is suitable to use when downsampling; we reshape then multiply by this matrix. From ec631abc21215cb0551534ada61508a8cf35e053 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jan 2025 21:32:21 +0800 Subject: [PATCH 0039/1191] Detach mean of prod in OrthognalLinear --- egs/librispeech/ASR/zipformer/scaling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9604cefd6b..d12e1b2e31 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -587,7 +587,8 @@ def forward(self, x: Tensor): if weight.shape[0] > weight.shape[1]: weight = weight.t() prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. - err = prod / prod.diag().mean() - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) + # detach the mean because in fp16 it may overflow; it doesn't affect the stable point. + err = prod / prod.diag().mean().detach() - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) err = (err ** 2).sum() ans = with_loss(ans, err * float(self.penalty_scale), self.name) if random.random() < 0.001 or __name__ == '__main__': From 775836f64f1adf02bdaed7ca3dd36e28ac131970 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Jan 2025 00:02:35 +0800 Subject: [PATCH 0040/1191] Allow orthogonality to have different scalar values in a differently implemented way --- egs/librispeech/ASR/zipformer/scaling.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d12e1b2e31..d8851c53d2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -570,6 +570,8 @@ class OrthogonalLinear(nn.Linear): def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): super().__init__(num_channels, num_channels, bias=False) self.penalty_scale = copy.deepcopy(penalty_scale) + self.max_product_scale = 100.0 + self.product_scale = nn.Parameter(torch.tensor(1.0)) self.name = None # will be set from training loop. for printing penalty. # by default, initialize to the identity. @@ -588,11 +590,12 @@ def forward(self, x: Tensor): weight = weight.t() prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. # detach the mean because in fp16 it may overflow; it doesn't affect the stable point. - err = prod / prod.diag().mean().detach() - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) + product_scale = limit_param_value(self.product_scale, min=0.1, max=self.max_product_scale) + err = prod * product_scale - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) err = (err ** 2).sum() - ans = with_loss(ans, err * float(self.penalty_scale), self.name) + ans = with_loss(ans, err * penalty_scale, self.name) if random.random() < 0.001 or __name__ == '__main__': - logging.info(f"{self.name}: dim={weight.shape}, avg_err = {err*float(self.penalty_scale)}={err}*{float(self.penalty_scale)}") + logging.info(f"{self.name}: 1/product_scale={1/self.product_scale}, dim={weight.shape}, avg_err = {err*float(self.penalty_scale)}={err}*{float(self.penalty_scale)}") return ans def OrthogonalLinearDownsampling(num_channels: int): From 1809018bdecad93f2829d035dd2b1f814d84bfe8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Jan 2025 00:07:43 +0800 Subject: [PATCH 0041/1191] Do it in log space --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d8851c53d2..3170b406a5 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -571,7 +571,7 @@ def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): super().__init__(num_channels, num_channels, bias=False) self.penalty_scale = copy.deepcopy(penalty_scale) self.max_product_scale = 100.0 - self.product_scale = nn.Parameter(torch.tensor(1.0)) + self.product_scale = nn.Parameter(torch.tensor(0.0)) self.name = None # will be set from training loop. for printing penalty. # by default, initialize to the identity. @@ -590,12 +590,12 @@ def forward(self, x: Tensor): weight = weight.t() prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. # detach the mean because in fp16 it may overflow; it doesn't affect the stable point. - product_scale = limit_param_value(self.product_scale, min=0.1, max=self.max_product_scale) + product_scale = limit_param_value(self.product_scale.exp(), min=0.1, max=self.max_product_scale) err = prod * product_scale - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) err = (err ** 2).sum() ans = with_loss(ans, err * penalty_scale, self.name) if random.random() < 0.001 or __name__ == '__main__': - logging.info(f"{self.name}: 1/product_scale={1/self.product_scale}, dim={weight.shape}, avg_err = {err*float(self.penalty_scale)}={err}*{float(self.penalty_scale)}") + logging.info(f"{self.name}: 1/product_scale={1/product_scale}, dim={weight.shape}, avg_err = {err*float(self.penalty_scale)}={err}*{float(self.penalty_scale)}") return ans def OrthogonalLinearDownsampling(num_channels: int): From a0a49ad91a4185153b402f3fec62ece29ba2c858 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Jan 2025 12:26:12 +0800 Subject: [PATCH 0042/1191] compute penalty of weight in OrthogonalLinear without autocast --- egs/librispeech/ASR/zipformer/scaling.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3170b406a5..af7807751b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -589,11 +589,13 @@ def forward(self, x: Tensor): if weight.shape[0] > weight.shape[1]: weight = weight.t() prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. - # detach the mean because in fp16 it may overflow; it doesn't affect the stable point. - product_scale = limit_param_value(self.product_scale.exp(), min=0.1, max=self.max_product_scale) - err = prod * product_scale - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) - err = (err ** 2).sum() - ans = with_loss(ans, err * penalty_scale, self.name) + with torch.cuda.amp.autocast(enabled=False): + # disabling autocast is to prevent the grad of product_scale from overflowing. + # detach the mean because in fp16 it may overflow; it doesn't affect the stable point. + product_scale = limit_param_value(self.product_scale.exp(), min=0.1, max=self.max_product_scale) + err = prod * product_scale - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) + err = (err ** 2).sum() + ans = with_loss(ans, err * penalty_scale, self.name) if random.random() < 0.001 or __name__ == '__main__': logging.info(f"{self.name}: 1/product_scale={1/product_scale}, dim={weight.shape}, avg_err = {err*float(self.penalty_scale)}={err}*{float(self.penalty_scale)}") return ans From b280eca2edd996c31e0e5b0dba8ad1c739d36c17 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Jan 2025 13:44:32 +0800 Subject: [PATCH 0043/1191] change dtype in WithLoss --- egs/librispeech/ASR/zipformer/scaling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index af7807751b..38b9461ebd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1182,6 +1182,7 @@ class WithLoss(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, y: Tensor, name: str): ctx.y_shape = y.shape + ctx.dtype = y.dtype if random.random() < 0.002 and name is not None: loss_sum = y.sum().item() logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") @@ -1191,7 +1192,7 @@ def forward(ctx, x: Tensor, y: Tensor, name: str): def backward(ctx, ans_grad: Tensor): return ( ans_grad, - torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + torch.ones(ctx.y_shape, dtype=ctx.dtype, device=ans_grad.device), None, ) From 272cfdc182118c230bd6a007d87b27c027a32691 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Jan 2025 13:44:32 +0800 Subject: [PATCH 0044/1191] change dtype in WithLoss --- egs/librispeech/ASR/zipformer/scaling.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 38b9461ebd..1f4df2961a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -570,8 +570,7 @@ class OrthogonalLinear(nn.Linear): def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): super().__init__(num_channels, num_channels, bias=False) self.penalty_scale = copy.deepcopy(penalty_scale) - self.max_product_scale = 100.0 - self.product_scale = nn.Parameter(torch.tensor(0.0)) + self.min_product_scale = 0.01 self.name = None # will be set from training loop. for printing penalty. # by default, initialize to the identity. @@ -589,15 +588,19 @@ def forward(self, x: Tensor): if weight.shape[0] > weight.shape[1]: weight = weight.t() prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. - with torch.cuda.amp.autocast(enabled=False): - # disabling autocast is to prevent the grad of product_scale from overflowing. - # detach the mean because in fp16 it may overflow; it doesn't affect the stable point. - product_scale = limit_param_value(self.product_scale.exp(), min=0.1, max=self.max_product_scale) - err = prod * product_scale - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype) - err = (err ** 2).sum() - ans = with_loss(ans, err * penalty_scale, self.name) + with torch.no_grad(): + alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) + alpha = alpha.clamp_(max=1. / self.min_product_scale) + + # following is equivalent to penalty_scale ((prod * alpha - I) ** + # 2).sum(), but more memory and compute efficient. + err = ((prod ** 2).sum() * (alpha ** 2 * penalty_scale) + + (-2 * alpha * penalty_scale) * prod.diag().sum() + + (prod.shape[0] ** 2 * penalty_scale)) + + ans = with_loss(ans, err, self.name) if random.random() < 0.001 or __name__ == '__main__': - logging.info(f"{self.name}: 1/product_scale={1/product_scale}, dim={weight.shape}, avg_err = {err*float(self.penalty_scale)}={err}*{float(self.penalty_scale)}") + logging.info(f"{self.name}: product_scale={1/alpha}, dim={weight.shape}, avg_err = {err} = {err/penalty_scale}*{penalty_scale}") return ans def OrthogonalLinearDownsampling(num_channels: int): From 8053474fa0aedc861ace1095d9c75221d464b219 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Jan 2025 16:28:20 +0800 Subject: [PATCH 0045/1191] remove some balancers and whiteners and put the BiasNorm after the invertibility stuff in ZipformerEncoderLayer. --- egs/librispeech/ASR/zipformer/zipformer.py | 34 ++-------------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index deb36f1842..44ad9fafa1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -539,14 +539,6 @@ def __init__( self.norm = BiasNorm(embed_dim) - self.balancer1 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.2, - max_abs=4.0, - ) # balancer for output of NonlinAttentionModule self.balancer_na = Balancer( @@ -581,22 +573,6 @@ def __init__( prob=0.05, ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.balancer2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.1, - max_abs=4.0, - ) - def forward( self, src: Tensor, @@ -625,7 +601,7 @@ def forward( ans = self.forward_internal(src, pos_emb, chunk_size, attn_mask, src_key_padding_mask) if not (randomize and self.training): - return ans + return self.norm(ans) # we view the input 'src' as x0 and the answer 'ans' as x1, like in a flow-matching # situation, and we compute an alternative version of x1 (called "x1" in the code) @@ -651,7 +627,7 @@ def forward( values, indexes = t_flat.sort() logging.info(f"name={self.name}: diff_scale={diff_scale[indexes]}, t={values}; global-scale={diff_sqscale.mean().sqrt()}") - return ans + rand + return self.norm(ans + rand) @@ -710,14 +686,8 @@ def forward_internal( ) src = src + self.balancer_ff3(self.feed_forward3(src)) - src = self.balancer1(src) - src = self.norm(src) - src = self.bypass(src_orig, src) - src = self.balancer2(src) - src = self.whiten(src) - return src def streaming_forward( From 8ae092b299b5385e5d41184019f7d4ba29420feb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Jan 2025 11:14:57 +0800 Subject: [PATCH 0046/1191] take zipformer.py from deterministic_invertible5frontend: change attn_score penalty schedule and prob, and move sqrt out of addition for invertibility penalty --- egs/librispeech/ASR/zipformer/zipformer.py | 26 ++++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 44ad9fafa1..188d83324b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = 1.0, + randomize_scale: FloatLike = ScheduledFloat((0.0, 1.5), (20000.0, 0.66)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -600,7 +600,7 @@ def forward( """ ans = self.forward_internal(src, pos_emb, chunk_size, attn_mask, src_key_padding_mask) - if not (randomize and self.training): + if torch.jit.is_scripting() or torch.jit.is_tracing() or not (randomize and self.training): return self.norm(ans) # we view the input 'src' as x0 and the answer 'ans' as x1, like in a flow-matching @@ -611,20 +611,21 @@ def forward( (seq_len, batch_size, emb_dim) = src.shape t = torch.empty(batch_size, 1, device=src.device).uniform_(0.1, 0.9) xt = src + (ans - src) * t + # ans_t is the network evaluated at t. it's interpreted as xt + vt. ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) x1 = xt + (ans_t - xt) * (1. - t) diff = (x1 - ans) / (t - t**2) diff_sqscale = (diff ** 2).mean(dim=2, keepdim=True) - G = 0.2 # scale on the global-mean part of the random-noise scale. + G = 0.1 # scale on the global-mean part of the random-noise scale. scale = float(self.randomize_scale) - diff_scale = (scale * G) * diff_sqscale.mean().sqrt() + (scale * (1. - G)) * diff_sqscale.sqrt() + with torch.cuda.amp.autocast(enabled=False): + diff_scale = ((scale * G) * diff_sqscale.to(torch.float).mean() + (scale * (1. - G)) * diff_sqscale).sqrt() rand = torch.randn_like(src) * diff_scale - if random.random() < 0.01 or __name__ == '__main__': + if random.random() < 0.001 or __name__ == '__main__': # logging output - diff_scale = (diff ** 2).mean(dim=(0, 2)).sqrt() - t_flat = t.flatten() - values, indexes = t_flat.sort() + diff_scale = (diff ** 2).mean(dim=(0, 2)).sqrt() # mean over all non-batch dims. + values, indexes = t.flatten().sort() logging.info(f"name={self.name}: diff_scale={diff_scale[indexes]}, t={values}; global-scale={diff_sqscale.mean().sqrt()}") return self.norm(ans + rand) @@ -869,7 +870,7 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] - invertible_layer = random.randint(0, len(self.layers)) + randomize_layer = random.randint(0, len(self.layers) - 1) for i, mod in enumerate(self.layers): src = mod( src, @@ -877,7 +878,7 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, - randomize=(i == invertible_layer), + randomize=(i == randomize_layer), ) if num_channels > layer_dim: @@ -1226,7 +1227,8 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. - self.attn_score_limit = copy.deepcopy(ScheduledFloat((0.0, 5.0), (40000.0, 20.0))) + self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) + self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads @@ -1369,7 +1371,7 @@ def forward( if torch.jit.is_scripting() or torch.jit.is_tracing(): pass - elif self.training and random.random() < 0.1: + elif self.training and random.random() < float(self.attn_score_penalty_prob): # This is a harder way of limiting the attention scores to not be # too large. It incurs a penalty if any of them has an absolute # value greater than 50.0. this should be outside the normal range From 7ded67d93032b0de8749bc1ee2c2372f8c461857 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Jan 2025 20:53:06 +0800 Subject: [PATCH 0047/1191] Reduce randomize_scale from 1.5..0.6 to 1.0.0.5. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 188d83324b..8ab12dfab8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 1.5), (20000.0, 0.66)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.5)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim From 1d433397880c1fac31ab5660c74010c00d481c5c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Jan 2025 21:01:58 +0800 Subject: [PATCH 0048/1191] Adjust noise scale depending on the probability with which we did the randomization here. --- egs/librispeech/ASR/zipformer/zipformer.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8ab12dfab8..7bda45ded0 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.5)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 0.5), (10000.0, 0.25)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -580,7 +580,7 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - randomize: bool = False, # do the invertibility-encouraging randomization if True. + randomize_factor: float = 0.0, # will be 1/(num-layers-this-stack) if randomizing, else 0. ) -> Tensor: """ Pass the input through the encoder layer. @@ -600,7 +600,7 @@ def forward( """ ans = self.forward_internal(src, pos_emb, chunk_size, attn_mask, src_key_padding_mask) - if torch.jit.is_scripting() or torch.jit.is_tracing() or not (randomize and self.training): + if torch.jit.is_scripting() or torch.jit.is_tracing() or not (self.training and randomize_factor != 0.0): return self.norm(ans) # we view the input 'src' as x0 and the answer 'ans' as x1, like in a flow-matching @@ -618,7 +618,7 @@ def forward( diff_sqscale = (diff ** 2).mean(dim=2, keepdim=True) G = 0.1 # scale on the global-mean part of the random-noise scale. - scale = float(self.randomize_scale) + scale = randomize_factor * float(self.randomize_scale) with torch.cuda.amp.autocast(enabled=False): diff_scale = ((scale * G) * diff_sqscale.to(torch.float).mean() + (scale * (1. - G)) * diff_sqscale).sqrt() rand = torch.randn_like(src) * diff_scale @@ -870,7 +870,8 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] - randomize_layer = random.randint(0, len(self.layers) - 1) + L = len(self.layers) + randomize_layer = random.randint(0, L - 1) for i, mod in enumerate(self.layers): src = mod( src, @@ -878,8 +879,11 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, - randomize=(i == randomize_layer), + randomize_factor=L ** -0.5 if i == randomize_layer else 0, ) + # the L ** -0.5 factor assumes that the "penalty" we pay in the loss + # will be proportioal to the square of the stddev, i.e. proportional + # to the noise variance. if num_channels > layer_dim: src = torch.cat((src, bypass), dim=-1) From 3963a28ebc4174acbc4a3953d6a3835832ceb526 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Jan 2025 22:59:55 +0800 Subject: [PATCH 0049/1191] Cosmetic fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1f4df2961a..e005801c69 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -596,7 +596,7 @@ def forward(self, x: Tensor): # 2).sum(), but more memory and compute efficient. err = ((prod ** 2).sum() * (alpha ** 2 * penalty_scale) + (-2 * alpha * penalty_scale) * prod.diag().sum() + - (prod.shape[0] ** 2 * penalty_scale)) + (prod.shape[0] * penalty_scale)) ans = with_loss(ans, err, self.name) if random.random() < 0.001 or __name__ == '__main__': From 49ca642a44949f39c14b6f451ecc0ce9969e131d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Jan 2025 12:49:42 +0800 Subject: [PATCH 0050/1191] Make it automatically add inf check hooks. --- egs/librispeech/ASR/zipformer/train.py | 2 ++ icefall/hooks.py | 35 ++++++++++++++++---------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 390367fc32..745de767e4 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1169,6 +1169,8 @@ def save_bad_model(suffix: str = ""): if not saved_bad_model: save_bad_model(suffix="-first-warning") saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() diff --git a/icefall/hooks.py b/icefall/hooks.py index 1c5bd2ae68..e27b71c0a1 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -39,28 +39,37 @@ def register_inf_check_hooks(model: nn.Module) -> None: # default param _name is a way to capture the current value of the variable "name". def forward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): - if not torch.isfinite(_output.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output is not finite: {_output}" - ) + try: + if not torch.isfinite(_output.to(torch.float32).sum()): + logging.warning( + f"The sum of {_name}.output is not finite: {_output}" + ) + except RuntimeError: # e.g. CUDA out of memory + pass elif isinstance(_output, tuple): for i, o in enumerate(_output): if isinstance(o, tuple): o = o[0] if not isinstance(o, Tensor): continue - if not torch.isfinite(o.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output[{i}] is not finite: {_output}" - ) - + try: + if not torch.isfinite(o.to(torch.float32).sum()): + logging.warning( + f"The sum of {_name}.output[{i}] is not finite: {_output}" + ) + except RuntimeError: # e.g. CUDA out of memory + pass # default param _name is a way to capture the current value of the variable "name". def backward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): - if not torch.isfinite(_output.to(torch.float32).sum()): - logging.warning( - f"The sum of {_name}.grad is not finite" # ": {_output}" - ) + try: + if not torch.isfinite(_output.to(torch.float32).sum()): + logging.warning( + f"The sum of {_name}.grad is not finite" # ": {_output}" + ) + except RuntimeError: # e.g. CUDA out of memory + pass + elif isinstance(_output, tuple): for i, o in enumerate(_output): if isinstance(o, tuple): From 9ec10873ec1e75a6387d443485ff3632244bf410 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Jan 2025 13:00:02 +0800 Subject: [PATCH 0051/1191] change where penalty_scale is applied to make overflow of grads less likely in fp16. --- egs/librispeech/ASR/zipformer/scaling.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e005801c69..43d9ec3168 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -588,19 +588,24 @@ def forward(self, x: Tensor): if weight.shape[0] > weight.shape[1]: weight = weight.t() prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. + # could include penalty_scale later on, but we do it at this point to make overflow of + # grads less likely (because they are aggregated earlier on, via sum()). + prod = scale_grad(prod, penalty_scale) with torch.no_grad(): alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) alpha = alpha.clamp_(max=1. / self.min_product_scale) # following is equivalent to penalty_scale ((prod * alpha - I) ** # 2).sum(), but more memory and compute efficient. - err = ((prod ** 2).sum() * (alpha ** 2 * penalty_scale) + - (-2 * alpha * penalty_scale) * prod.diag().sum() + - (prod.shape[0] * penalty_scale)) + err = ((prod ** 2).sum() * (alpha ** 2) + + (-2 * alpha) * prod.diag().sum() + + prod.shape[0]) ans = with_loss(ans, err, self.name) if random.random() < 0.001 or __name__ == '__main__': - logging.info(f"{self.name}: product_scale={1/alpha}, dim={weight.shape}, avg_err = {err} = {err/penalty_scale}*{penalty_scale}") + with torch.no_grad(): + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"{self.name}: product_scale={1/alpha}, dim={weight.shape}, avg_err = {err} * {penalty_scale} = {err*penalty_scale}, ans-rms={ans_rms}") return ans def OrthogonalLinearDownsampling(num_channels: int): From 9f25f1ce694b0822b47a71ff51bdf17070796705 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Jan 2025 00:00:45 +0800 Subject: [PATCH 0052/1191] print different things in log message --- egs/librispeech/ASR/zipformer/zipformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7bda45ded0..83460a33ec 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -622,11 +622,11 @@ def forward( with torch.cuda.amp.autocast(enabled=False): diff_scale = ((scale * G) * diff_sqscale.to(torch.float).mean() + (scale * (1. - G)) * diff_sqscale).sqrt() rand = torch.randn_like(src) * diff_scale - if random.random() < 0.001 or __name__ == '__main__': + if random.random() < 0.01 or __name__ == '__main__': # logging output - diff_scale = (diff ** 2).mean(dim=(0, 2)).sqrt() # mean over all non-batch dims. - values, indexes = t.flatten().sort() - logging.info(f"name={self.name}: diff_scale={diff_scale[indexes]}, t={values}; global-scale={diff_sqscale.mean().sqrt()}") + ans_scale = (ans ** 2).mean().sqrt() + vt_scale = ((ans - src) ** 2).mean().sqrt() + logging.info(f"name={self.name}: ans_scale={ans_scale}, vt_scale={vt_scale}, diff-scale={diff_sqscale.mean().sqrt()}") return self.norm(ans + rand) From 9e910a845dc0a6a80d1af3ecc7cdca6c9a7e0ebf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Jan 2025 16:30:36 +0800 Subject: [PATCH 0053/1191] Fix importance-sampling factor, was inverted and sqrted --- egs/librispeech/ASR/zipformer/zipformer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 83460a33ec..3294a4d767 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -580,7 +580,9 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - randomize_factor: float = 0.0, # will be 1/(num-layers-this-stack) if randomizing, else 0. + randomize_factor: float = 0.0, # will be 1/(probability with which we + # randomized this layer) if randomizing, + # else 0. ) -> Tensor: """ Pass the input through the encoder layer. @@ -870,8 +872,13 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] + + randomize_proportion = 0.25 L = len(self.layers) - randomize_layer = random.randint(0, L - 1) + # int(...) rounds down. we'll only randomize >= 2 layers if L >= 8. + num_randomize = max(1, int(0.5 + L * randomize_proportion)) + randomize_layer = [ True ] * num_randomize + [ False ] * (L - num_randomize) + random.shuffle(randomize_layer) for i, mod in enumerate(self.layers): src = mod( src, @@ -879,11 +886,10 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, - randomize_factor=L ** -0.5 if i == randomize_layer else 0, + randomize_factor=(L / num_randomize) if randomize_layer[i] else 0, ) - # the L ** -0.5 factor assumes that the "penalty" we pay in the loss - # will be proportioal to the square of the stddev, i.e. proportional - # to the noise variance. + # randomize_factor can be viewed as a simple version of an + # importance-sampling factor. if num_channels > layer_dim: src = torch.cat((src, bypass), dim=-1) From fa6606cdb79ad4c3d200404ae7e1cebeee906b33 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 11 Jan 2025 23:34:14 +0800 Subject: [PATCH 0054/1191] reduce randomize_scale from 0.5->0.2 to 0.2->0.1 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3294a4d767..22c6a2d0c9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 0.5), (10000.0, 0.25)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 0.2), (10000.0, 0.1)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim From 964a869f16a88d82c337712a89d2501d1203d6c2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Jan 2025 00:23:43 +0800 Subject: [PATCH 0055/1191] Restore upper end of randomize_scale from .2 to 0.5 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 22c6a2d0c9..a4720944bd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -485,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 0.2), (10000.0, 0.1)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 0.5), (10000.0, 0.1)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim From a5b97ce24336933fcbf6cc7445171f7a916d3bfe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Jan 2025 21:52:20 +0800 Subject: [PATCH 0056/1191] Add a learnable epsilon to BiasNorm, I believe this should make it easier to ensure it is invertible, early in training. --- egs/librispeech/ASR/zipformer/scaling.py | 49 ++++++++++++++++-------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 43d9ec3168..c28b0a8fb4 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -375,6 +375,7 @@ def forward( ctx, x: Tensor, bias: Tensor, + log_eps: Tensor, log_scale: Tensor, channel_dim: int, store_output_for_backprop: bool, @@ -387,7 +388,8 @@ def forward( for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + torch.mean((x - bias) ** 2 + log_eps.exp(), + dim=channel_dim, keepdim=True) ** -0.5 ) * log_scale.exp() ans = x * scales ctx.save_for_backward( @@ -395,28 +397,33 @@ def forward( scales.detach(), bias.detach(), log_scale.detach(), + log_eps.detach(), ) return ans @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale = ctx.saved_tensors + ans_or_x, scales, bias, log_scale, log_eps = ctx.saved_tensors if ctx.store_output_for_backprop: x = ans_or_x / scales else: x = ans_or_x x = x.detach() - x.requires_grad = True - bias.requires_grad = True - log_scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, bias and log_scale. - scales = ( - torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() - ans = x * scales - ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_scale.grad, None, None + with torch.cuda.amp.autocast(enabled=False): + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + log_eps.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = ( + torch.mean((x - bias) ** 2 + log_eps.exp(), + dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + + return x.grad, bias.grad.flatten(), log_eps.grad, log_scale.grad, None, None class BiasNorm(torch.nn.Module): @@ -430,7 +437,8 @@ class BiasNorm(torch.nn.Module): LayerNorm are required to allow it to do this. Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization. We also give it a (scalar) + computing the scale for normalization, in addition to a separate trainable + "eps" parameter, learned in log-space. We also give it a (scalar) trainable scale on the output. @@ -464,6 +472,7 @@ def __init__( self.channel_dim = channel_dim self.log_scale = nn.Parameter(torch.tensor(log_scale)) self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) + self.log_eps = nn.Parameter(torch.tensor(0.0)) self.log_scale_min = log_scale_min self.log_scale_max = log_scale_max @@ -481,7 +490,8 @@ def forward(self, x: Tensor) -> Tensor: for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + torch.mean((x - bias) ** 2 + self.log_eps.exp(), + dim=channel_dim, keepdim=True) ** -0.5 ) * self.log_scale.exp() return x * scales @@ -491,12 +501,19 @@ def forward(self, x: Tensor) -> Tensor: max=float(self.log_scale_max), training=self.training, ) + log_eps = limit_param_value( + self.log_eps, + min=-5, max=5, # mainly to prevent infinities and zeroes + training=self.training, + ) return BiasNormFunction.apply( - x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + x, self.bias, self.log_eps, log_scale, self.channel_dim, self.store_output_for_backprop ) + + def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear From a683998e2f57db51df93afae344f5af72c47c4c9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Jan 2025 15:53:44 +0800 Subject: [PATCH 0057/1191] Change how upsampling and downsampling are done; have max_proj_dim=2*max(encoder_dim) --- egs/librispeech/ASR/zipformer/scaling.py | 35 ++++------- egs/librispeech/ASR/zipformer/zipformer.py | 68 +++++++++++++++++----- 2 files changed, 62 insertions(+), 41 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c28b0a8fb4..d9463b5fcc 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -625,36 +625,21 @@ def forward(self, x: Tensor): logging.info(f"{self.name}: product_scale={1/alpha}, dim={weight.shape}, avg_err = {err} * {penalty_scale} = {err*penalty_scale}, ans-rms={ans_rms}") return ans -def OrthogonalLinearDownsampling(num_channels: int): +def OrthogonalLinearSpecial(num_channels: int, penalty_scale: float = 1000.0): # returns a parameterized nn.Linear that stays orthogonal, with a special initialization # that is suitable to use when downsampling; we reshape then multiply by this matrix. assert num_channels % 2 == 0 - ans = OrthogonalLinear(num_channels) - inv_sqrt2 = 2 ** -0.5 - N = num_channels // 2 - eye = inv_sqrt2 * torch.eye(N) - # four blocks: (1/sqrt(2)) (1, 1; -1, 1) + ans = OrthogonalLinear(num_channels, penalty_scale=penalty_scale) + # want to initialize weight as: + # 1/sqrt(2) * M, where M is a block-diagonal matrix with 2x2 blocks [ 1 1; 1 -1 ] with torch.no_grad(): - ans.weight[:N, :N] = eye - ans.weight[:N, N:] = eye - ans.weight[N:, :N] = -eye - ans.weight[N:, N:] = eye - return ans + inv_sqrt2 = 2 ** -0.5 + ans.weight[:] = 0.0 + ans.weight[0::2, 0::2] = inv_sqrt2 + ans.weight[0::2, 1::2] = inv_sqrt2 + ans.weight[1::2, 0::2] = inv_sqrt2 + ans.weight[1::2, 1::2] = -inv_sqrt2 -def OrthogonalLinearUpsampling(num_channels: int): - # returns a parameterized nn.Linear that stays orthogonal, with a special initialization - # that is suitable to use when downsampling; we multiply by this matrix then reshape. - assert num_channels % 2 == 0 - ans = OrthogonalLinear(num_channels) - inv_sqrt2 = 2 ** -0.5 - N = num_channels // 2 - eye = inv_sqrt2 * torch.eye(N) - # four blocks: (1/sqrt(2)) (1, -1; 1, 1) - with torch.no_grad(): - ans.weight[:N, :N] = eye - ans.weight[:N, N:] = -eye - ans.weight[N:, :N] = eye - ans.weight[N:, N:] = eye return ans diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a4720944bd..2b0b21870c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -27,8 +27,7 @@ from encoder_interface import EncoderInterface from scaling import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - OrthogonalLinearUpsampling, - OrthogonalLinearDownsampling, + OrthogonalLinearSpecial, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, Balancer, @@ -154,13 +153,18 @@ def _to_tuple(x): # caution: some changes we made for this break the streaming, later we'll try to fix this. encoders_downsampling_factors = [ ] + # the following is basically heuristic; max(encoder_dim) would be OK also. + max_proj_dim = 2 * max(encoder_dim) + def set_downsample_factor(cur_downsample, ds): while cur_downsample < ds: # need to downsample - encoders.append(InvertibleDownsample(input_dim * cur_downsample)) + encoders.append(InvertibleDownsample(channels=input_dim * cur_downsample, + proj_dim=min(2 * input_dim * cur_downsample, max_proj_dim))) cur_downsample *= 2 while cur_downsample > ds: - encoders.append(InvertibleUpsample(input_dim * cur_downsample)) + encoders.append(InvertibleUpsample(channels=input_dim * cur_downsample, + proj_dim=min(input_dim * cur_downsample, max_proj_dim))) cur_downsample //= 2 return cur_downsample @@ -1030,14 +1034,26 @@ def forward(self, src_orig: Tensor, src: Tensor): class InvertibleDownsample(torch.nn.Module): """ - Does downsampling in an invertible way, by a factor of two. + Does downsampling in an invertible way, by a factor of two. Projection is initialized + in a special way and enforced to be orthogonal. + + Args: + channels: the number of input channels; the num output channels will be twice this + proj_dim: the number of channels, after combining 2 frames by interpolating their channels + as [ a b a b, .. ] that will actually be projected; the rest are just copied. + proj_dim=2 * channels would mean all channels are projected in a learned way + causal: True for causal systems, only affects error messages as requires even + input num frames. + penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because + it may interact with the scale of the loss function, i.e. if the loss-function + scale is smaller you may want this to be smaller. """ def __init__( - self, channels: int, causal: bool = False, + self, channels: int, proj_dim: int, causal: bool = False, penalty_scale: float = 1000.0, ): super().__init__() - - self.proj = OrthogonalLinearDownsampling(channels * 2) + assert proj_dim <= channels * 2 + self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -1056,18 +1072,30 @@ def forward(self, src: Tensor) -> Tensor: src = torch.cat((src, src[-1:]), dim=0) seq_len += 1 - src = src.permute(1, 0, 2).reshape(batch_size, seq_len // 2, in_channels * 2) - src = self.proj(src) - src = src.permute(1, 0, 2) # (seq_len // 2, batch_size, in_channels * 2) + # the following will place each 2 frames of a particular channel right after + # each other as if they were two different channels. + src = torch.stack((src[0::2], src[1::2]), dim=-1) + src = src.reshape(seq_len // 2, batch_size, in_channels * 2) + + return src class InvertibleUpsample(torch.nn.Module): """ A very simple form of upsampling that is the inverse of InvertibleDownsampling. + Projection is initialized in a special way and enforced to be orthogonal. + + proj_dim: the number of channels that will actually be projected; the rest are just copied. + proj_dim=channels would mean all channels are projected in a learned way + penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because + it may interact with the scale of the loss function, i.e. if the loss-function + scale is smaller you may want this to be smaller. + """ - def __init__(self, channels: int): + def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0): super().__init__() - self.proj = OrthogonalLinearUpsampling(channels) + assert proj_dim <= channels + self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) def forward(self, src: Tensor) -> Tensor: """ @@ -1075,10 +1103,18 @@ def forward(self, src: Tensor) -> Tensor: Returns a tensor of shape ( (seq_len*2), batch_size, num_channels // 2) """ - src = self.proj(src) + proj_channels = self.proj.weight.shape[0] (seq_len, batch_size, in_channels) = src.shape - src = src.permute(1, 0, 2).reshape(batch_size, seq_len * 2, in_channels // 2) - src = src.permute(1, 0, 2) # (seq_len * 2, batch_size, in_channels // 2) + + if proj_channels < in_channels: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + + src = torch.stack((src[..., 0::2], src[..., 1::2]), + dim=1) # (seq_len, 2, batch_size, in_channels // 2) + src = src.reshape(seq_len * 2, batch_size, in_channels // 2) return src From 026dceee572c01ba815c46598f4a8dcbfb0fd47e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Jan 2025 16:04:24 +0800 Subject: [PATCH 0058/1191] Bug fixes --- egs/librispeech/ASR/zipformer/scaling.py | 3 +++ egs/librispeech/ASR/zipformer/zipformer.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d9463b5fcc..99d2b3d2e2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -639,6 +639,9 @@ def OrthogonalLinearSpecial(num_channels: int, penalty_scale: float = 1000.0): ans.weight[0::2, 1::2] = inv_sqrt2 ans.weight[1::2, 0::2] = inv_sqrt2 ans.weight[1::2, 1::2] = -inv_sqrt2 + N = ans.weight.shape[0] + ans.weight *= (torch.arange(N)[:, None] // 2 == + torch.arange(N)[None, :] // 2) return ans diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2b0b21870c..455e272b27 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1076,8 +1076,12 @@ def forward(self, src: Tensor) -> Tensor: # each other as if they were two different channels. src = torch.stack((src[0::2], src[1::2]), dim=-1) src = src.reshape(seq_len // 2, batch_size, in_channels * 2) - - + proj_channels = self.proj.weight.shape[0] + if proj_channels < in_channels * 2: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) return src class InvertibleUpsample(torch.nn.Module): From 0cfbe7a6bf9f5b51981dca6968294805bec74737 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Jan 2025 17:06:58 +0800 Subject: [PATCH 0059/1191] Simplify how bias works in BiasNorm; have in_bias and out_bias. --- egs/librispeech/ASR/zipformer/scaling.py | 49 ++++++++---------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 99d2b3d2e2..95417d101d 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -365,7 +365,7 @@ def backward(ctx, x_grad, *args): class BiasNormFunction(torch.autograd.Function): # This computes: - # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # scales = (torch.mean(x ** 2 + log_eps.exp(), keepdim=True)) ** -0.5 * log_scale.exp() # return x * scales # (after unsqueezing the bias), but it does it in a memory-efficient way so that # it can just store the returned value (chances are, this will also be needed for @@ -374,28 +374,21 @@ class BiasNormFunction(torch.autograd.Function): def forward( ctx, x: Tensor, - bias: Tensor, log_eps: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool, ) -> Tensor: - assert bias.ndim == 1 if channel_dim < 0: channel_dim = channel_dim + x.ndim - ctx.store_output_for_backprop = store_output_for_backprop ctx.channel_dim = channel_dim - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) scales = ( - torch.mean((x - bias) ** 2 + log_eps.exp(), + torch.mean(x ** 2 + log_eps.exp(), dim=channel_dim, keepdim=True) ** -0.5 ) * log_scale.exp() ans = x * scales ctx.save_for_backward( - ans.detach() if store_output_for_backprop else x, + x, scales.detach(), - bias.detach(), log_scale.detach(), log_eps.detach(), ) @@ -403,27 +396,22 @@ def forward( @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale, log_eps = ctx.saved_tensors - if ctx.store_output_for_backprop: - x = ans_or_x / scales - else: - x = ans_or_x + x, scales, log_scale, log_eps = ctx.saved_tensors x = x.detach() with torch.cuda.amp.autocast(enabled=False): x.requires_grad = True - bias.requires_grad = True log_scale.requires_grad = True log_eps.requires_grad = True with torch.enable_grad(): - # recompute scales from x, bias and log_scale. + # recompute scales from x, log_eps and log_scale. scales = ( - torch.mean((x - bias) ** 2 + log_eps.exp(), + torch.mean(x ** 2 + log_eps.exp(), dim=ctx.channel_dim, keepdim=True) ** -0.5 ) * log_scale.exp() ans = x * scales ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_eps.grad, log_scale.grad, None, None + return x.grad, log_eps.grad, log_scale.grad, None, None class BiasNorm(torch.nn.Module): @@ -452,10 +440,6 @@ class BiasNorm(torch.nn.Module): is learnable. log_scale_min: FloatLike, minimum allowed value of log_scale log_scale_max: FloatLike, maximum allowed value of log_scale - store_output_for_backprop: only possibly affects memory use; recommend - to set to True if you think the output of this module is more likely - than the input of this module to be required to be stored for the - backprop. """ def __init__( @@ -465,35 +449,32 @@ def __init__( log_scale: float = 1.0, log_scale_min: float = -1.5, log_scale_max: float = 1.5, - store_output_for_backprop: bool = False, ) -> None: super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) + self.in_bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) + self.out_bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) self.log_eps = nn.Parameter(torch.tensor(0.0)) self.log_scale_min = log_scale_min self.log_scale_max = log_scale_max - self.store_output_for_backprop = store_output_for_backprop - def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels + x = x + self.in_bias + if torch.jit.is_scripting() or torch.jit.is_tracing(): channel_dim = self.channel_dim if channel_dim < 0: channel_dim += x.ndim - bias = self.bias - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) scales = ( - torch.mean((x - bias) ** 2 + self.log_eps.exp(), + torch.mean(x ** 2 + self.log_eps.exp(), dim=channel_dim, keepdim=True) ** -0.5 ) * self.log_scale.exp() - return x * scales + return (x * scales) + self.out_bias log_scale = limit_param_value( self.log_scale, @@ -508,8 +489,8 @@ def forward(self, x: Tensor) -> Tensor: ) return BiasNormFunction.apply( - x, self.bias, self.log_eps, log_scale, self.channel_dim, self.store_output_for_backprop - ) + x, self.log_eps, log_scale, self.channel_dim + ) + self.out_bias From 1acceaa00455df16b7b38be9a1e6680f51468fe5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Jan 2025 22:59:28 +0800 Subject: [PATCH 0060/1191] update train.py for grad scaling and check frequency --- egs/librispeech/ASR/zipformer/train.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 745de767e4..7f456e6816 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1157,14 +1157,9 @@ def save_bad_model(suffix: str = ""): rank=rank, ) - if batch_idx % 25 == 0 and params.use_autocast: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. + if params.use_autocast: cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 2.0 or (cur_grad_scale < 8.0 and batch_idx % 100 == 0) or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: save_bad_model(suffix="-first-warning") @@ -1172,10 +1167,19 @@ def save_bad_model(suffix: str = ""): if not params.inf_check: register_inf_check_hooks(model) logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: save_bad_model() raise_grad_scale_is_too_small_error(cur_grad_scale) + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or + batch_idx % 100 == 0 and cur_grad_scale < 8.0 or + batch_idx % 400 == 0 and cur_grad_scale < 32.0): + scaler.update(cur_grad_scale * 2.0) + if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 From 71237bb480c375a9ba00903b174c8eb0e7700ba7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Jan 2025 23:44:59 +0800 Subject: [PATCH 0061/1191] Introduce eps_noise in BiasNorm --- egs/librispeech/ASR/zipformer/scaling.py | 32 ++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 95417d101d..16d4122291 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -377,26 +377,33 @@ def forward( log_eps: Tensor, log_scale: Tensor, channel_dim: int, + log_eps_noise: float ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim + + noise_shape = list(x.shape) + noise_shape[channel_dim] = 1 + eps_noise = torch.randn_like(x) * log_eps_noise + scales = ( - torch.mean(x ** 2 + log_eps.exp(), + torch.mean(x ** 2 + log_eps.exp() + eps_noise, dim=channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() + ) * (0.5 * eps_noise + log_scale).exp() ans = x * scales ctx.save_for_backward( x, scales.detach(), log_scale.detach(), log_eps.detach(), + eps_noise.detach(), ) return ans @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, scales, log_scale, log_eps = ctx.saved_tensors + x, scales, log_scale, log_eps, eps_noise = ctx.saved_tensors x = x.detach() with torch.cuda.amp.autocast(enabled=False): x.requires_grad = True @@ -405,13 +412,13 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: with torch.enable_grad(): # recompute scales from x, log_eps and log_scale. scales = ( - torch.mean(x ** 2 + log_eps.exp(), + torch.mean(x ** 2 + log_eps.exp() + eps_noise, dim=ctx.channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() + ) * (0.5 * eps_noise + log_scale).exp() ans = x * scales ans.backward(gradient=ans_grad) - return x.grad, log_eps.grad, log_scale.grad, None, None + return x.grad, log_eps.grad, log_scale.grad, None, None, None class BiasNorm(torch.nn.Module): @@ -458,6 +465,11 @@ def __init__( self.out_bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) self.log_eps = nn.Parameter(torch.tensor(0.0)) + # scale on noise we add to log_eps as part of a mechanism to encourage it to stay relatively large + # compared to x. + self.log_eps_noise = ScheduledFloat((0.0, 0.05), (20000.0, 0.02), default=0.0) + self.name = None + self.log_scale_min = log_scale_min self.log_scale_max = log_scale_max @@ -488,8 +500,14 @@ def forward(self, x: Tensor) -> Tensor: training=self.training, ) + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + in_bias_rms = (self.in_bias ** 2).mean().sqrt() + out_bias_rms = (self.out_bias ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, in_bias_rms={in_bias_rms}, out_bias_rms={out_bias_rms}, log_scale={self.log_scale.item()}, log_eps_noise={self.log_eps_noise.item()}") + return BiasNormFunction.apply( - x, self.log_eps, log_scale, self.channel_dim + x, self.log_eps, log_scale, self.channel_dim, float(self.log_eps_noise), ) + self.out_bias From 0ab3bc98a8ff58cc2223000e0ad812e3d88f5c98 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Jan 2025 23:58:43 +0800 Subject: [PATCH 0062/1191] Fix spelling error in messages --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 16d4122291..7455052c52 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -504,7 +504,7 @@ def forward(self, x: Tensor) -> Tensor: x_rms = (x ** 2).mean().sqrt() in_bias_rms = (self.in_bias ** 2).mean().sqrt() out_bias_rms = (self.out_bias ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, in_bias_rms={in_bias_rms}, out_bias_rms={out_bias_rms}, log_scale={self.log_scale.item()}, log_eps_noise={self.log_eps_noise.item()}") + logging.info(f"name={self.name}: x_rms={x_rms}, in_bias_rms={in_bias_rms}, out_bias_rms={out_bias_rms}, log_scale={self.log_scale.item()}, log_eps={self.log_eps.item()}") return BiasNormFunction.apply( x, self.log_eps, log_scale, self.channel_dim, float(self.log_eps_noise), From 7eb24df8dfdd16159e3216e9e8c682c5a4310cc3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Jan 2025 14:58:53 +0800 Subject: [PATCH 0063/1191] Take zipformer.py from 14conv --- egs/librispeech/ASR/zipformer/zipformer.py | 94 +++++----------------- 1 file changed, 22 insertions(+), 72 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 455e272b27..3ae4bd31dc 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -27,7 +27,8 @@ from encoder_interface import EncoderInterface from scaling import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - OrthogonalLinearSpecial, + OrthogonalLinearUpsampling, + OrthogonalLinearDownsampling, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, Balancer, @@ -153,18 +154,13 @@ def _to_tuple(x): # caution: some changes we made for this break the streaming, later we'll try to fix this. encoders_downsampling_factors = [ ] - # the following is basically heuristic; max(encoder_dim) would be OK also. - max_proj_dim = 2 * max(encoder_dim) - def set_downsample_factor(cur_downsample, ds): while cur_downsample < ds: # need to downsample - encoders.append(InvertibleDownsample(channels=input_dim * cur_downsample, - proj_dim=min(2 * input_dim * cur_downsample, max_proj_dim))) + encoders.append(InvertibleDownsample(input_dim * cur_downsample)) cur_downsample *= 2 while cur_downsample > ds: - encoders.append(InvertibleUpsample(channels=input_dim * cur_downsample, - proj_dim=min(input_dim * cur_downsample, max_proj_dim))) + encoders.append(InvertibleUpsample(input_dim * cur_downsample)) cur_downsample //= 2 return cur_downsample @@ -489,7 +485,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 0.5), (10000.0, 0.1)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.5)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -584,9 +580,7 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - randomize_factor: float = 0.0, # will be 1/(probability with which we - # randomized this layer) if randomizing, - # else 0. + randomize: bool = False, # do the invertibility-encouraging randomization if True. ) -> Tensor: """ Pass the input through the encoder layer. @@ -606,7 +600,7 @@ def forward( """ ans = self.forward_internal(src, pos_emb, chunk_size, attn_mask, src_key_padding_mask) - if torch.jit.is_scripting() or torch.jit.is_tracing() or not (self.training and randomize_factor != 0.0): + if torch.jit.is_scripting() or torch.jit.is_tracing() or not (randomize and self.training): return self.norm(ans) # we view the input 'src' as x0 and the answer 'ans' as x1, like in a flow-matching @@ -624,7 +618,7 @@ def forward( diff_sqscale = (diff ** 2).mean(dim=2, keepdim=True) G = 0.1 # scale on the global-mean part of the random-noise scale. - scale = randomize_factor * float(self.randomize_scale) + scale = float(self.randomize_scale) with torch.cuda.amp.autocast(enabled=False): diff_scale = ((scale * G) * diff_sqscale.to(torch.float).mean() + (scale * (1. - G)) * diff_sqscale).sqrt() rand = torch.randn_like(src) * diff_scale @@ -876,13 +870,7 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] - - randomize_proportion = 0.25 - L = len(self.layers) - # int(...) rounds down. we'll only randomize >= 2 layers if L >= 8. - num_randomize = max(1, int(0.5 + L * randomize_proportion)) - randomize_layer = [ True ] * num_randomize + [ False ] * (L - num_randomize) - random.shuffle(randomize_layer) + randomize_layer = random.randint(0, len(self.layers) - 1) for i, mod in enumerate(self.layers): src = mod( src, @@ -890,10 +878,8 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, - randomize_factor=(L / num_randomize) if randomize_layer[i] else 0, + randomize=(i == randomize_layer), ) - # randomize_factor can be viewed as a simple version of an - # importance-sampling factor. if num_channels > layer_dim: src = torch.cat((src, bypass), dim=-1) @@ -1034,26 +1020,14 @@ def forward(self, src_orig: Tensor, src: Tensor): class InvertibleDownsample(torch.nn.Module): """ - Does downsampling in an invertible way, by a factor of two. Projection is initialized - in a special way and enforced to be orthogonal. - - Args: - channels: the number of input channels; the num output channels will be twice this - proj_dim: the number of channels, after combining 2 frames by interpolating their channels - as [ a b a b, .. ] that will actually be projected; the rest are just copied. - proj_dim=2 * channels would mean all channels are projected in a learned way - causal: True for causal systems, only affects error messages as requires even - input num frames. - penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because - it may interact with the scale of the loss function, i.e. if the loss-function - scale is smaller you may want this to be smaller. + Does downsampling in an invertible way, by a factor of two. """ def __init__( - self, channels: int, proj_dim: int, causal: bool = False, penalty_scale: float = 1000.0, + self, channels: int, causal: bool = False, ): super().__init__() - assert proj_dim <= channels * 2 - self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) + + self.proj = OrthogonalLinearDownsampling(channels * 2) self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -1072,34 +1046,18 @@ def forward(self, src: Tensor) -> Tensor: src = torch.cat((src, src[-1:]), dim=0) seq_len += 1 - # the following will place each 2 frames of a particular channel right after - # each other as if they were two different channels. - src = torch.stack((src[0::2], src[1::2]), dim=-1) - src = src.reshape(seq_len // 2, batch_size, in_channels * 2) - proj_channels = self.proj.weight.shape[0] - if proj_channels < in_channels * 2: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) + src = src.permute(1, 0, 2).reshape(batch_size, seq_len // 2, in_channels * 2) + src = self.proj(src) + src = src.permute(1, 0, 2) # (seq_len // 2, batch_size, in_channels * 2) return src class InvertibleUpsample(torch.nn.Module): """ A very simple form of upsampling that is the inverse of InvertibleDownsampling. - Projection is initialized in a special way and enforced to be orthogonal. - - proj_dim: the number of channels that will actually be projected; the rest are just copied. - proj_dim=channels would mean all channels are projected in a learned way - penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because - it may interact with the scale of the loss function, i.e. if the loss-function - scale is smaller you may want this to be smaller. - """ - def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0): + def __init__(self, channels: int): super().__init__() - assert proj_dim <= channels - self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) + self.proj = OrthogonalLinearUpsampling(channels) def forward(self, src: Tensor) -> Tensor: """ @@ -1107,18 +1065,10 @@ def forward(self, src: Tensor) -> Tensor: Returns a tensor of shape ( (seq_len*2), batch_size, num_channels // 2) """ - proj_channels = self.proj.weight.shape[0] + src = self.proj(src) (seq_len, batch_size, in_channels) = src.shape - - if proj_channels < in_channels: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) - - src = torch.stack((src[..., 0::2], src[..., 1::2]), - dim=1) # (seq_len, 2, batch_size, in_channels // 2) - src = src.reshape(seq_len * 2, batch_size, in_channels // 2) + src = src.permute(1, 0, 2).reshape(batch_size, seq_len * 2, in_channels // 2) + src = src.permute(1, 0, 2) # (seq_len * 2, batch_size, in_channels // 2) return src From 45e7ee4a6eac3695ff6c03d1420783e1f9047e26 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Jan 2025 15:07:43 +0800 Subject: [PATCH 0064/1191] reset zipformer.py to 39conv but change randomize_scale/randomize_factor to match 14conv and set max_proj_dim to infinity --- egs/librispeech/ASR/zipformer/zipformer.py | 88 +++++++++++++++++----- 1 file changed, 68 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3ae4bd31dc..a5f328f885 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -27,8 +27,7 @@ from encoder_interface import EncoderInterface from scaling import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - OrthogonalLinearUpsampling, - OrthogonalLinearDownsampling, + OrthogonalLinearSpecial, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, Balancer, @@ -154,13 +153,18 @@ def _to_tuple(x): # caution: some changes we made for this break the streaming, later we'll try to fix this. encoders_downsampling_factors = [ ] + # make it so large the limit is never reached. + max_proj_dim = max(downsampling_factor) * max(encoder_dim) + def set_downsample_factor(cur_downsample, ds): while cur_downsample < ds: # need to downsample - encoders.append(InvertibleDownsample(input_dim * cur_downsample)) + encoders.append(InvertibleDownsample(channels=input_dim * cur_downsample, + proj_dim=min(2 * input_dim * cur_downsample, max_proj_dim))) cur_downsample *= 2 while cur_downsample > ds: - encoders.append(InvertibleUpsample(input_dim * cur_downsample)) + encoders.append(InvertibleUpsample(channels=input_dim * cur_downsample, + proj_dim=min(input_dim * cur_downsample, max_proj_dim))) cur_downsample //= 2 return cur_downsample @@ -580,7 +584,7 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - randomize: bool = False, # do the invertibility-encouraging randomization if True. + randomize: bool = False, ) -> Tensor: """ Pass the input through the encoder layer. @@ -600,7 +604,7 @@ def forward( """ ans = self.forward_internal(src, pos_emb, chunk_size, attn_mask, src_key_padding_mask) - if torch.jit.is_scripting() or torch.jit.is_tracing() or not (randomize and self.training): + if torch.jit.is_scripting() or torch.jit.is_tracing() or not (self.training and randomize): return self.norm(ans) # we view the input 'src' as x0 and the answer 'ans' as x1, like in a flow-matching @@ -870,7 +874,13 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] - randomize_layer = random.randint(0, len(self.layers) - 1) + + randomize_proportion = 0.25 + L = len(self.layers) + # int(...) rounds down. we'll only randomize >= 2 layers if L >= 8. + num_randomize = max(1, int(0.5 + L * randomize_proportion)) + randomize_layer = [ True ] * num_randomize + [ False ] * (L - num_randomize) + random.shuffle(randomize_layer) for i, mod in enumerate(self.layers): src = mod( src, @@ -878,8 +888,10 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, - randomize=(i == randomize_layer), + randomize=randomize_layer[i], ) + # randomize_factor can be viewed as a simple version of an + # importance-sampling factor. if num_channels > layer_dim: src = torch.cat((src, bypass), dim=-1) @@ -1020,14 +1032,26 @@ def forward(self, src_orig: Tensor, src: Tensor): class InvertibleDownsample(torch.nn.Module): """ - Does downsampling in an invertible way, by a factor of two. + Does downsampling in an invertible way, by a factor of two. Projection is initialized + in a special way and enforced to be orthogonal. + + Args: + channels: the number of input channels; the num output channels will be twice this + proj_dim: the number of channels, after combining 2 frames by interpolating their channels + as [ a b a b, .. ] that will actually be projected; the rest are just copied. + proj_dim=2 * channels would mean all channels are projected in a learned way + causal: True for causal systems, only affects error messages as requires even + input num frames. + penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because + it may interact with the scale of the loss function, i.e. if the loss-function + scale is smaller you may want this to be smaller. """ def __init__( - self, channels: int, causal: bool = False, + self, channels: int, proj_dim: int, causal: bool = False, penalty_scale: float = 1000.0, ): super().__init__() - - self.proj = OrthogonalLinearDownsampling(channels * 2) + assert proj_dim <= channels * 2 + self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -1046,18 +1070,34 @@ def forward(self, src: Tensor) -> Tensor: src = torch.cat((src, src[-1:]), dim=0) seq_len += 1 - src = src.permute(1, 0, 2).reshape(batch_size, seq_len // 2, in_channels * 2) - src = self.proj(src) - src = src.permute(1, 0, 2) # (seq_len // 2, batch_size, in_channels * 2) + # the following will place each 2 frames of a particular channel right after + # each other as if they were two different channels. + src = torch.stack((src[0::2], src[1::2]), dim=-1) + src = src.reshape(seq_len // 2, batch_size, in_channels * 2) + proj_channels = self.proj.weight.shape[0] + if proj_channels < in_channels * 2: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) return src class InvertibleUpsample(torch.nn.Module): """ A very simple form of upsampling that is the inverse of InvertibleDownsampling. + Projection is initialized in a special way and enforced to be orthogonal. + + proj_dim: the number of channels that will actually be projected; the rest are just copied. + proj_dim=channels would mean all channels are projected in a learned way + penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because + it may interact with the scale of the loss function, i.e. if the loss-function + scale is smaller you may want this to be smaller. + """ - def __init__(self, channels: int): + def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0): super().__init__() - self.proj = OrthogonalLinearUpsampling(channels) + assert proj_dim <= channels + self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) def forward(self, src: Tensor) -> Tensor: """ @@ -1065,10 +1105,18 @@ def forward(self, src: Tensor) -> Tensor: Returns a tensor of shape ( (seq_len*2), batch_size, num_channels // 2) """ - src = self.proj(src) + proj_channels = self.proj.weight.shape[0] (seq_len, batch_size, in_channels) = src.shape - src = src.permute(1, 0, 2).reshape(batch_size, seq_len * 2, in_channels // 2) - src = src.permute(1, 0, 2) # (seq_len * 2, batch_size, in_channels // 2) + + if proj_channels < in_channels: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + + src = torch.stack((src[..., 0::2], src[..., 1::2]), + dim=1) # (seq_len, 2, batch_size, in_channels // 2) + src = src.reshape(seq_len * 2, batch_size, in_channels // 2) return src From 8d0bedb069032f38e910527184b148c9f46d9c51 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 14 Jan 2025 21:43:13 +0800 Subject: [PATCH 0065/1191] Fix bug RE shape of eps_noise in modified BiasNorm --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7455052c52..b37f00dc1b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -385,7 +385,7 @@ def forward( noise_shape = list(x.shape) noise_shape[channel_dim] = 1 - eps_noise = torch.randn_like(x) * log_eps_noise + eps_noise = torch.randn(*noise_shape, device=x.device, dtype=x.dtype) * log_eps_noise scales = ( torch.mean(x ** 2 + log_eps.exp() + eps_noise, From e186ab54ae17a8d4fc57a8bdf2fd06f9977a1746 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Jan 2025 18:17:32 +0800 Subject: [PATCH 0066/1191] Remove in_bias and out_bias, and fix another bug in the modified-BiasNorm formula, forgot to have the addition of eps_noise be inside the log. --- egs/librispeech/ASR/zipformer/scaling.py | 26 ++++++++++-------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b37f00dc1b..3ed5bba621 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -383,12 +383,15 @@ def forward( channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim - noise_shape = list(x.shape) - noise_shape[channel_dim] = 1 - eps_noise = torch.randn(*noise_shape, device=x.device, dtype=x.dtype) * log_eps_noise + if log_eps_noise != 0.0: + noise_shape = list(x.shape) + noise_shape[channel_dim] = 1 + eps_noise = torch.randn(*noise_shape, device=x.device, dtype=x.dtype) * log_eps_noise + else: + eps_noise = torch.zeros_like(log_eps) scales = ( - torch.mean(x ** 2 + log_eps.exp() + eps_noise, + torch.mean(x ** 2 + (log_eps + eps_noise).exp(), dim=channel_dim, keepdim=True) ** -0.5 ) * (0.5 * eps_noise + log_scale).exp() ans = x * scales @@ -412,7 +415,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: with torch.enable_grad(): # recompute scales from x, log_eps and log_scale. scales = ( - torch.mean(x ** 2 + log_eps.exp() + eps_noise, + torch.mean(x ** 2 + (log_eps + eps_noise).exp(), dim=ctx.channel_dim, keepdim=True) ** -0.5 ) * (0.5 * eps_noise + log_scale).exp() ans = x * scales @@ -461,8 +464,6 @@ def __init__( self.num_channels = num_channels self.channel_dim = channel_dim self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.in_bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) - self.out_bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) self.log_eps = nn.Parameter(torch.tensor(0.0)) # scale on noise we add to log_eps as part of a mechanism to encourage it to stay relatively large @@ -476,8 +477,6 @@ def __init__( def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - x = x + self.in_bias - if torch.jit.is_scripting() or torch.jit.is_tracing(): channel_dim = self.channel_dim if channel_dim < 0: @@ -486,7 +485,7 @@ def forward(self, x: Tensor) -> Tensor: torch.mean(x ** 2 + self.log_eps.exp(), dim=channel_dim, keepdim=True) ** -0.5 ) * self.log_scale.exp() - return (x * scales) + self.out_bias + return (x * scales) log_scale = limit_param_value( self.log_scale, @@ -502,14 +501,11 @@ def forward(self, x: Tensor) -> Tensor: if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() - in_bias_rms = (self.in_bias ** 2).mean().sqrt() - out_bias_rms = (self.out_bias ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, in_bias_rms={in_bias_rms}, out_bias_rms={out_bias_rms}, log_scale={self.log_scale.item()}, log_eps={self.log_eps.item()}") + logging.info(f"name={self.name}: x_rms={x_rms}, log_scale={self.log_scale.item()}, log_eps={self.log_eps.item()}, (0.5*log_eps).exp()/x_rms={(0.5*self.log_eps).exp()/x_rms}") return BiasNormFunction.apply( x, self.log_eps, log_scale, self.channel_dim, float(self.log_eps_noise), - ) + self.out_bias - + ) From be2783ea9396540ba2a0caa624eea23bcb5150d9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Jan 2025 18:22:06 +0800 Subject: [PATCH 0067/1191] set final value of log_eps_noise to zero --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3ed5bba621..1592d49b20 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -468,7 +468,7 @@ def __init__( # scale on noise we add to log_eps as part of a mechanism to encourage it to stay relatively large # compared to x. - self.log_eps_noise = ScheduledFloat((0.0, 0.05), (20000.0, 0.02), default=0.0) + self.log_eps_noise = ScheduledFloat((0.0, 0.05), (20000.0, 0.0), default=0.0) self.name = None self.log_scale_min = log_scale_min From 8256c303688aeabaa4c3a6ff84a35f9aad501c26 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Jan 2025 00:55:42 +0800 Subject: [PATCH 0068/1191] Initialize second orthogonal projection with transpose, so they are inverses. --- egs/librispeech/ASR/zipformer/scaling.py | 8 +++++--- egs/librispeech/ASR/zipformer/zipformer.py | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1592d49b20..1a216af526 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -620,7 +620,9 @@ def forward(self, x: Tensor): logging.info(f"{self.name}: product_scale={1/alpha}, dim={weight.shape}, avg_err = {err} * {penalty_scale} = {err*penalty_scale}, ans-rms={ans_rms}") return ans -def OrthogonalLinearSpecial(num_channels: int, penalty_scale: float = 1000.0): +def OrthogonalLinearSpecial(num_channels: int, + penalty_scale: float = 1000.0, + transpose: bool = False): # returns a parameterized nn.Linear that stays orthogonal, with a special initialization # that is suitable to use when downsampling; we reshape then multiply by this matrix. assert num_channels % 2 == 0 @@ -632,8 +634,8 @@ def OrthogonalLinearSpecial(num_channels: int, penalty_scale: float = 1000.0): ans.weight[:] = 0.0 ans.weight[0::2, 0::2] = inv_sqrt2 ans.weight[0::2, 1::2] = inv_sqrt2 - ans.weight[1::2, 0::2] = inv_sqrt2 - ans.weight[1::2, 1::2] = -inv_sqrt2 + ans.weight[1::2, 0::2] = -inv_sqrt2 if transpose else inv_sqrt2 + ans.weight[1::2, 1::2] = inv_sqrt2 if transpose else -inv_sqrt2 N = ans.weight.shape[0] ans.weight *= (torch.arange(N)[:, None] // 2 == torch.arange(N)[None, :] // 2) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a5f328f885..2eaf038af9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1097,7 +1097,9 @@ class InvertibleUpsample(torch.nn.Module): def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0): super().__init__() assert proj_dim <= channels - self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) + self.proj = OrthogonalLinearSpecial(proj_dim, + penalty_scale=penalty_scale, + transpose=True) def forward(self, src: Tensor) -> Tensor: """ From 80ce9596fa31c8bd42a417aaf2cac2c485488631 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Jan 2025 15:01:23 +0800 Subject: [PATCH 0069/1191] Remove about half the modules in each zipformer layer --- egs/librispeech/ASR/zipformer/zipformer.py | 53 ++-------------------- 1 file changed, 3 insertions(+), 50 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a5f328f885..9dc3311215 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -500,8 +500,6 @@ def __init__( self.bypass = BypassModule( embed_dim, ) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, @@ -514,45 +512,21 @@ def __init__( self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) - self.feed_forward1 = FeedforwardModule( embed_dim, (feedforward_dim * 3) // 4, dropout ) self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule( - embed_dim, (feedforward_dim * 5) // 4, dropout - ) - - self.nonlin_attention = NonlinAttention( - embed_dim, hidden_channels=3 * embed_dim // 4 - ) - - self.conv_module1 = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) - self.conv_module2 = ConvolutionModule( + self.conv_module = ConvolutionModule( embed_dim, cnn_module_kernel, causal=causal ) - # TODO: remove it - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) self.norm = BiasNorm(embed_dim) - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) # balancer for output of feedforward2, prevent it from staying too # small. give this a very small probability, even at the start of @@ -567,16 +541,6 @@ def __init__( prob=0.05, ) - self.balancer_ff3 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - def forward( self, src: Tensor, @@ -671,26 +635,14 @@ def forward_internal( src = src + self.feed_forward1(src) - src = src + self.balancer_na(self.nonlin_attention(src, attn_weights[0:1])) - src = src + self.self_attn1(src, attn_weights) - src = src + self.conv_module1( + src = src + self.conv_module( src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask ) src = src + self.balancer_ff2(self.feed_forward2(src)) - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - src = src + self.self_attn2(src, attn_weights) - - src = src + self.conv_module2( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ) - src = src + self.balancer_ff3(self.feed_forward3(src)) - src = self.bypass(src_orig, src) return src @@ -753,6 +705,7 @@ def streaming_forward( src = src + self.feed_forward1(src) + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( src, attn_weights[0:1], From 17fd636f1b974e7fdb7f47331712d4d5c702c007 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Jan 2025 14:45:30 +0800 Subject: [PATCH 0070/1191] make part of Balancer run on OOM --- egs/librispeech/ASR/zipformer/scaling.py | 109 ++++++++++++++--------- 1 file changed, 67 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1a216af526..19dafd9d46 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -810,6 +810,51 @@ def streaming_forward( return x_chunk + x_causal, cache +def balancer_backward_func(x, x_grad, min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim): + # this was taken out of the Balancer backward function. + # returns modified version of x_grad. + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = m_loss + r_loss + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + return x_grad + + class BalancerFunction(torch.autograd.Function): @staticmethod def forward( @@ -835,50 +880,30 @@ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None] (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [i for i in range(x.ndim) if i != channel_dim] - uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = m_loss + r_loss - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = ( - (loss_grad**2) - .mean(dim=mean_dims, keepdim=True) - .sqrt() - .clamp(min=1.0e-20) - ) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) + x_grad = balancer_backward_func(x, x_grad, min_mean, max_mean, min_rms, + max_rms, grad_scale, channel_dim) except Exception as e: logging.info( - f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will balance a part of it." ) - + try: + # will take a piece of x_grad in this dimension. + dim_to_split = 0 if channel_dim != 0 else 1 + size = x.shape[dim_to_split] + + x_grad_part = balancer_backward_func(x.narrow(dim_to_split, 0, size // 4), + x_grad.narrow(dim_to_split, 0, size // 4), + min_mean, max_mean, min_rms, + max_rms, grad_scale, channel_dim) + del x # save memory + x_grad = torch.cat([x_grad_part, x_grad.narrow(dim_to_split, + size // 4, + size - size // 4)], + dim_to_split) + except Exception as e: + logging.info( + f"Caught exception second time in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None, None, None, None, None, None @@ -1972,7 +1997,7 @@ def isclose(a, b): assert isclose(x1.grad, x2.grad) def _test_orthogonal_linear(): - for t in (OrthogonalLinear, OrthogonalLinearUpsampling, OrthogonalLinearDownsampling): + for t in (OrthogonalLinear, OrthogonalLinearSpecial): m = t(128) m(torch.randn(30, 2, 128)) From 7a2174c70150815baf965eb3bdf1b0d856df7e65 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Jan 2025 20:26:52 +0800 Subject: [PATCH 0071/1191] simplify BiasNorm, taking things out of log space, and introduce max_scale. --- egs/librispeech/ASR/zipformer/scaling.py | 91 ++++++++++-------------- 1 file changed, 36 insertions(+), 55 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 19dafd9d46..cfe68b4f65 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -374,54 +374,46 @@ class BiasNormFunction(torch.autograd.Function): def forward( ctx, x: Tensor, - log_eps: Tensor, - log_scale: Tensor, + eps: Tensor, + max_scale: Tensor, + scale: Tensor, channel_dim: int, - log_eps_noise: float ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim - if log_eps_noise != 0.0: - noise_shape = list(x.shape) - noise_shape[channel_dim] = 1 - eps_noise = torch.randn(*noise_shape, device=x.device, dtype=x.dtype) * log_eps_noise - else: - eps_noise = torch.zeros_like(log_eps) - - scales = ( - torch.mean(x ** 2 + (log_eps + eps_noise).exp(), - dim=channel_dim, keepdim=True) ** -0.5 - ) * (0.5 * eps_noise + log_scale).exp() + x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) + scales = scale * torch.maximum(x_sq * max_scale, + x_sq + eps) ** -0.5 ans = x * scales ctx.save_for_backward( - x, + x.detach(), + eps.detach(), + max_scale.detach(), + scale.detach(), scales.detach(), - log_scale.detach(), - log_eps.detach(), - eps_noise.detach(), ) return ans @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, scales, log_scale, log_eps, eps_noise = ctx.saved_tensors - x = x.detach() + x, eps, max_scale, scale, scales = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): x.requires_grad = True - log_scale.requires_grad = True - log_eps.requires_grad = True + eps.requires_grad = True + max_scale.requires_grad = True + scale.requires_grad = True + with torch.enable_grad(): - # recompute scales from x, log_eps and log_scale. - scales = ( - torch.mean(x ** 2 + (log_eps + eps_noise).exp(), - dim=ctx.channel_dim, keepdim=True) ** -0.5 - ) * (0.5 * eps_noise + log_scale).exp() + + x_sq = torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + scales = scale * torch.maximum(x_sq * max_scale, + x_sq + eps) ** -0.5 ans = x * scales ans.backward(gradient=ans_grad) - return x.grad, log_eps.grad, log_scale.grad, None, None, None + return x.grad, eps.grad, max_scale.grad, scale.grad, None class BiasNorm(torch.nn.Module): @@ -456,23 +448,16 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, ) -> None: super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.log_eps = nn.Parameter(torch.tensor(0.0)) + self.scale = nn.Parameter(torch.tensor(1.0)) + self.eps = nn.Parameter(torch.tensor(1.0)) + self.max_scale = nn.Parameter(torch.tensor(2.0)) - # scale on noise we add to log_eps as part of a mechanism to encourage it to stay relatively large - # compared to x. - self.log_eps_noise = ScheduledFloat((0.0, 0.05), (20000.0, 0.0), default=0.0) self.name = None - self.log_scale_min = log_scale_min - self.log_scale_max = log_scale_max def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels @@ -481,30 +466,26 @@ def forward(self, x: Tensor) -> Tensor: channel_dim = self.channel_dim if channel_dim < 0: channel_dim += x.ndim - scales = ( - torch.mean(x ** 2 + self.log_eps.exp(), - dim=channel_dim, keepdim=True) ** -0.5 - ) * self.log_scale.exp() + x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) + scales = self.scale * torch.maximum(x_sq * self.max_scale, + x_sq + self.eps) ** -0.5 return (x * scales) - log_scale = limit_param_value( - self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training, - ) - log_eps = limit_param_value( - self.log_eps, - min=-5, max=5, # mainly to prevent infinities and zeroes - training=self.training, - ) + eps = limit_param_value( + self.eps, min=0.5, max=4.0, training=self.training) + + max_scale = limit_param_value( + self.max_scale, min=1.5, max=4.0, training=self.training) + + scale = limit_param_value( + self.scale, min=0.5, max=4.0, training=self.training) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, log_scale={self.log_scale.item()}, log_eps={self.log_eps.item()}, (0.5*log_eps).exp()/x_rms={(0.5*self.log_eps).exp()/x_rms}") + logging.info(f"name={self.name}: x_rms={x_rms}, eps={eps.item()}, max_scale={max_scale.item()}, scale={scale.item()}, sqrt(eps)/x_rms={eps.sqrt()/x_rms}") return BiasNormFunction.apply( - x, self.log_eps, log_scale, self.channel_dim, float(self.log_eps_noise), + x, eps, max_scale, scale, self.channel_dim, ) From 916a2e70220d034a10a90de31ca37dc8e90af684 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Jan 2025 20:53:06 +0800 Subject: [PATCH 0072/1191] small bug fixes, remove unnecessarily saved thing --- egs/librispeech/ASR/zipformer/scaling.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index cfe68b4f65..7481ac973a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -392,13 +392,12 @@ def forward( eps.detach(), max_scale.detach(), scale.detach(), - scales.detach(), ) return ans @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, eps, max_scale, scale, scales = ctx.saved_tensors + x, eps, max_scale, scale = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): x.requires_grad = True eps.requires_grad = True @@ -452,7 +451,7 @@ def __init__( super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.scale = nn.Parameter(torch.tensor(1.0)) + self.scale = nn.Parameter(torch.tensor(2.0)) self.eps = nn.Parameter(torch.tensor(1.0)) self.max_scale = nn.Parameter(torch.tensor(2.0)) From b13824901bd97ba5e4b7a53ac01ee002d501375e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Jan 2025 21:16:12 +0800 Subject: [PATCH 0073/1191] adjust initial value and limits of max_scale --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7481ac973a..9c214e768a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -453,7 +453,7 @@ def __init__( self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(2.0)) self.eps = nn.Parameter(torch.tensor(1.0)) - self.max_scale = nn.Parameter(torch.tensor(2.0)) + self.max_scale = nn.Parameter(torch.tensor(1.2)) self.name = None @@ -474,7 +474,7 @@ def forward(self, x: Tensor) -> Tensor: self.eps, min=0.5, max=4.0, training=self.training) max_scale = limit_param_value( - self.max_scale, min=1.5, max=4.0, training=self.training) + self.max_scale, min=1.05, max=3.0, training=self.training) scale = limit_param_value( self.scale, min=0.5, max=4.0, training=self.training) From 3675e0c61201ee508ba08bc8e51f15582fe5056a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 11:35:53 +0800 Subject: [PATCH 0074/1191] change how BiasNorm works, take out the torch.maximum and max_scale and introduce a power. --- egs/librispeech/ASR/zipformer/scaling.py | 30 ++++++++++-------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9c214e768a..40fd059641 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -375,7 +375,7 @@ def forward( ctx, x: Tensor, eps: Tensor, - max_scale: Tensor, + power: Tensor, scale: Tensor, channel_dim: int, ) -> Tensor: @@ -384,35 +384,32 @@ def forward( ctx.channel_dim = channel_dim x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = scale * torch.maximum(x_sq * max_scale, - x_sq + eps) ** -0.5 + scales = scale * (x_sq ** power + eps) ** (-0.5 / power) ans = x * scales ctx.save_for_backward( x.detach(), eps.detach(), - max_scale.detach(), + power.detach(), scale.detach(), ) return ans @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, eps, max_scale, scale = ctx.saved_tensors + x, eps, power, scale = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): x.requires_grad = True eps.requires_grad = True - max_scale.requires_grad = True + power.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - x_sq = torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) - scales = scale * torch.maximum(x_sq * max_scale, - x_sq + eps) ** -0.5 + scales = scale * (x_sq ** power + eps) ** (-0.5 / power) ans = x * scales ans.backward(gradient=ans_grad) - return x.grad, eps.grad, max_scale.grad, scale.grad, None + return x.grad, eps.grad, power.grad, scale.grad, None class BiasNorm(torch.nn.Module): @@ -453,7 +450,7 @@ def __init__( self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(2.0)) self.eps = nn.Parameter(torch.tensor(1.0)) - self.max_scale = nn.Parameter(torch.tensor(1.2)) + self.power = nn.Parameter(torch.tensor(1.0)) self.name = None @@ -466,25 +463,24 @@ def forward(self, x: Tensor) -> Tensor: if channel_dim < 0: channel_dim += x.ndim x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = self.scale * torch.maximum(x_sq * self.max_scale, - x_sq + self.eps) ** -0.5 + scales = self.scale * (x_sq ** self.power + self.eps) ** (-0.5 / self.power) return (x * scales) eps = limit_param_value( self.eps, min=0.5, max=4.0, training=self.training) - max_scale = limit_param_value( - self.max_scale, min=1.05, max=3.0, training=self.training) + power = limit_param_value( + self.power, min=0.9, max=3.0, training=self.training) scale = limit_param_value( self.scale, min=0.5, max=4.0, training=self.training) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, eps={eps.item()}, max_scale={max_scale.item()}, scale={scale.item()}, sqrt(eps)/x_rms={eps.sqrt()/x_rms}") + logging.info(f"name={self.name}: x_rms={x_rms}, power={power.item()}, eps**(1/power)={(eps ** (1/power))}, scale={scale.item()}, (eps**(0.5/power))/x_rms={(eps**(0.5/power))/x_rms}") return BiasNormFunction.apply( - x, eps, max_scale, scale, self.channel_dim, + x, eps, power, scale, self.channel_dim, ) From 6ca511cf1f7f140f8872fe76238edb4803389fcc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 12:17:05 +0800 Subject: [PATCH 0075/1191] Reduce min of power to 0.5 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 40fd059641..59a685590e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -470,7 +470,7 @@ def forward(self, x: Tensor) -> Tensor: self.eps, min=0.5, max=4.0, training=self.training) power = limit_param_value( - self.power, min=0.9, max=3.0, training=self.training) + self.power, min=0.5, max=3.0, training=self.training) scale = limit_param_value( self.scale, min=0.5, max=4.0, training=self.training) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 77713ac69d..e045fb1726 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -973,6 +973,7 @@ def _get_bypass_scale(self, batch_size: int): ans = torch.maximum(ans, mask.to(ans.dtype)) return ans + def forward(self, src_orig: Tensor, src: Tensor): """ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) From 24bb8a1eed51b3e32858dcbf136c343c695df636 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 14:40:09 +0800 Subject: [PATCH 0076/1191] Reduce min power from .5 to .25 and max from 4 to 2. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 59a685590e..c54ab1efb6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -473,7 +473,7 @@ def forward(self, x: Tensor) -> Tensor: self.power, min=0.5, max=3.0, training=self.training) scale = limit_param_value( - self.scale, min=0.5, max=4.0, training=self.training) + self.scale, min=0.25, max=2.0, training=self.training) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() From 18609f236560ea491fce7ac604ce4f8dbb79a398 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 15:13:12 +0800 Subject: [PATCH 0077/1191] Fix error where I had changed limits of scale not power --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c54ab1efb6..007952e4e6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -470,10 +470,10 @@ def forward(self, x: Tensor) -> Tensor: self.eps, min=0.5, max=4.0, training=self.training) power = limit_param_value( - self.power, min=0.5, max=3.0, training=self.training) + self.power, min=0.25, max=2.0, training=self.training) scale = limit_param_value( - self.scale, min=0.25, max=2.0, training=self.training) + self.scale, min=0.5, max=4.0, training=self.training) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() From d018cc4f3b6e36b5aa8d9d02b80c27026b21cc35 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 17:47:19 +0800 Subject: [PATCH 0078/1191] Clamp grads of parameters of BiasNormFunction --- egs/librispeech/ASR/zipformer/scaling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 007952e4e6..26cb05c63b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -409,7 +409,12 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: ans = x * scales ans.backward(gradient=ans_grad) - return x.grad, eps.grad, power.grad, scale.grad, None + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode. + return x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(eps.grad), c(power.grad), c(scale.grad), None class BiasNorm(torch.nn.Module): From d118afe159484211c47fa0c98a3fb2abec77f192 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 18:09:41 +0800 Subject: [PATCH 0079/1191] Take Balancer changes to run a part on OOM; adjust balancer settings to decrease limits and increase prob. --- egs/librispeech/ASR/zipformer/scaling.py | 109 +++++++++++++-------- egs/librispeech/ASR/zipformer/zipformer.py | 10 +- 2 files changed, 72 insertions(+), 47 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1592d49b20..2d8d5b1b6a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -808,6 +808,51 @@ def streaming_forward( return x_chunk + x_causal, cache +def balancer_backward_func(x, x_grad, min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim): + # this was taken out of the Balancer backward function. + # returns modified version of x_grad. + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = m_loss + r_loss + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + return x_grad + + class BalancerFunction(torch.autograd.Function): @staticmethod def forward( @@ -833,50 +878,30 @@ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None] (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [i for i in range(x.ndim) if i != channel_dim] - uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = m_loss + r_loss - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = ( - (loss_grad**2) - .mean(dim=mean_dims, keepdim=True) - .sqrt() - .clamp(min=1.0e-20) - ) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) + x_grad = balancer_backward_func(x, x_grad, min_mean, max_mean, min_rms, + max_rms, grad_scale, channel_dim) except Exception as e: logging.info( - f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will balance a part of it." ) - + try: + # will take a piece of x_grad in this dimension. + dim_to_split = 0 if channel_dim != 0 else 1 + size = x.shape[dim_to_split] + + x_grad_part = balancer_backward_func(x.narrow(dim_to_split, 0, size // 4), + x_grad.narrow(dim_to_split, 0, size // 4), + min_mean, max_mean, min_rms, + max_rms, grad_scale, channel_dim) + del x # save memory + x_grad = torch.cat([x_grad_part, x_grad.narrow(dim_to_split, + size // 4, + size - size // 4)], + dim_to_split) + except Exception as e: + logging.info( + f"Caught exception second time in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None, None, None, None, None, None @@ -1970,7 +1995,7 @@ def isclose(a, b): assert isclose(x1.grad, x2.grad) def _test_orthogonal_linear(): - for t in (OrthogonalLinear, OrthogonalLinearUpsampling, OrthogonalLinearDownsampling): + for t in (OrthogonalLinear, OrthogonalLinearSpecial): m = t(128) m(torch.randn(30, 2, 128)) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a5f328f885..a040cd5a22 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -551,7 +551,7 @@ def __init__( min_positive=0.3, max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage + prob=0.1, ) # balancer for output of feedforward2, prevent it from staying too @@ -562,9 +562,9 @@ def __init__( channel_dim=-1, min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.05), default=0.0), max_abs=2.0, - prob=0.05, + prob=0.1, ) self.balancer_ff3 = Balancer( @@ -572,9 +572,9 @@ def __init__( channel_dim=-1, min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), max_abs=4.0, - prob=0.05, + prob=0.1, ) def forward( From d8e6cde1c775e47c57c18e9267867d39253fb3a4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 19:59:54 +0800 Subject: [PATCH 0080/1191] attempt to fix nan grad of power by using float32. --- egs/librispeech/ASR/zipformer/scaling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 26cb05c63b..1e2c38582f 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -397,6 +397,8 @@ def forward( @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: x, eps, power, scale = ctx.saved_tensors + power, eps, scale = power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + power, eps, scale = power.detach(), eps.detach(), scale.detach() with torch.cuda.amp.autocast(enabled=False): x.requires_grad = True eps.requires_grad = True From e9f6ff2de1ea0356a45412e953e9887e9384da5e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 22:35:18 +0800 Subject: [PATCH 0081/1191] move some conversions inside no-autocast region --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1e2c38582f..c62084cc73 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -397,9 +397,9 @@ def forward( @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: x, eps, power, scale = ctx.saved_tensors - power, eps, scale = power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) - power, eps, scale = power.detach(), eps.detach(), scale.detach() with torch.cuda.amp.autocast(enabled=False): + power, eps, scale = power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + power, eps, scale = power.detach(), eps.detach(), scale.detach() x.requires_grad = True eps.requires_grad = True power.requires_grad = True From 346b8986d5417b9ef700197eb641b817ac0264de Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jan 2025 22:39:02 +0800 Subject: [PATCH 0082/1191] Use float32 more thoroughly in backward of BiasNorm --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c62084cc73..a50598a8e3 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -398,8 +398,8 @@ def forward( def backward(ctx, ans_grad: Tensor) -> Tensor: x, eps, power, scale = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): - power, eps, scale = power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) - power, eps, scale = power.detach(), eps.detach(), scale.detach() + x, power, eps, scale = x.to(torch.float32), power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, power, eps, scale = x.detach(), power.detach(), eps.detach(), scale.detach() x.requires_grad = True eps.requires_grad = True power.requires_grad = True @@ -409,7 +409,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: x_sq = torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) scales = scale * (x_sq ** power + eps) ** (-0.5 / power) ans = x * scales - ans.backward(gradient=ans_grad) + ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): # this is to replace infinities that might be thrown up From 293aaf3eb409a361ae9649d7477c3a4fbdadfeca Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Jan 2025 18:33:05 +0800 Subject: [PATCH 0083/1191] drafts --- egs/librispeech/ASR/zipformer/zipformer.py | 31 +++++++++------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e91d9472c0..d0aeafbd6a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -489,7 +489,7 @@ def __init__( dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.5)), + randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.75)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -527,18 +527,6 @@ def __init__( self.norm = BiasNorm(embed_dim) - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.05), default=0.0), - max_abs=2.0, - prob=0.1, - ) def forward( self, @@ -583,12 +571,19 @@ def forward( x1 = xt + (ans_t - xt) * (1. - t) diff = (x1 - ans) / (t - t**2) - diff_sqscale = (diff ** 2).mean(dim=2, keepdim=True) - G = 0.1 # scale on the global-mean part of the random-noise scale. + diff_scale = (diff ** 2).mean(dim=2, keepdim=True).sqrt() + ans_scale = (ans ** 2).mean(dim=2, keepdim=True).sqrt() scale = float(self.randomize_scale) with torch.cuda.amp.autocast(enabled=False): - diff_scale = ((scale * G) * diff_sqscale.to(torch.float).mean() + (scale * (1. - G)) * diff_sqscale).sqrt() - rand = torch.randn_like(src) * diff_scale + # float(self.randomize_scale) * diff_scale is the main term that penalizes deviations from + # linear "flow". 0.01 is a constant term that will motivate the network to increase the + # dynamic range of the activations. 0.1 * (ans_scale - 1).relu() is a term that + # will start adding a penalty if any frames have rms value greater than 1, so in combination + # with the 0.01 constant term this should keep the activations just under 1; this will + # also help to discourage 'outlier' frames that have larger-than-normal norm. + noise_scale = float(self.randomize_scale) * diff_scale + 0.01 + 0.1 * (ans_scale - 1).relu() + + rand = torch.randn_like(src) * noise_scale if random.random() < 0.01 or __name__ == '__main__': # logging output ans_scale = (ans ** 2).mean().sqrt() @@ -640,7 +635,7 @@ def forward_internal( src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask ) - src = src + self.balancer_ff2(self.feed_forward2(src)) + src = src + self.feed_forward2(src) src = self.bypass(src_orig, src) From 9d42e8e197023debe73c86e4e165f596ed7bd98d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Jan 2025 18:59:04 +0800 Subject: [PATCH 0084/1191] Change formula for noise_scale to encourage feature rms to be near 1.0; make BiasNorm eps be learned in log space; remove balancer_ff2. --- egs/librispeech/ASR/zipformer/scaling.py | 30 +++++++++++----------- egs/librispeech/ASR/zipformer/zipformer.py | 19 +++++++------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a50598a8e3..8dfe7bb95f 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -374,7 +374,7 @@ class BiasNormFunction(torch.autograd.Function): def forward( ctx, x: Tensor, - eps: Tensor, + log_eps: Tensor, power: Tensor, scale: Tensor, channel_dim: int, @@ -384,11 +384,11 @@ def forward( ctx.channel_dim = channel_dim x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = scale * (x_sq ** power + eps) ** (-0.5 / power) + scales = scale * (x_sq ** power + log_eps.exp()) ** (-0.5 / power) ans = x * scales ctx.save_for_backward( x.detach(), - eps.detach(), + log_eps.detach(), power.detach(), scale.detach(), ) @@ -396,18 +396,18 @@ def forward( @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, eps, power, scale = ctx.saved_tensors + x, log_eps, power, scale = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): - x, power, eps, scale = x.to(torch.float32), power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) - x, power, eps, scale = x.detach(), power.detach(), eps.detach(), scale.detach() + x, power, log_eps, scale = x.to(torch.float32), power.to(torch.float32), log_eps.to(torch.float32), scale.to(torch.float32) + x, power, log_eps, scale = x.detach(), power.detach(), log_eps.detach(), scale.detach() x.requires_grad = True - eps.requires_grad = True + log_eps.requires_grad = True power.requires_grad = True scale.requires_grad = True with torch.enable_grad(): x_sq = torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) - scales = scale * (x_sq ** power + eps) ** (-0.5 / power) + scales = scale * (x_sq ** power + log_eps.exp()) ** (-0.5 / power) ans = x * scales ans.backward(gradient=ans_grad.to(torch.float32)) @@ -416,7 +416,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(eps.grad), c(power.grad), c(scale.grad), None + return x.grad, c(log_eps.grad), c(power.grad), c(scale.grad), None class BiasNorm(torch.nn.Module): @@ -456,7 +456,7 @@ def __init__( self.num_channels = num_channels self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(2.0)) - self.eps = nn.Parameter(torch.tensor(1.0)) + self.log_eps = nn.Parameter(torch.tensor(0.0)) self.power = nn.Parameter(torch.tensor(1.0)) self.name = None @@ -470,11 +470,11 @@ def forward(self, x: Tensor) -> Tensor: if channel_dim < 0: channel_dim += x.ndim x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = self.scale * (x_sq ** self.power + self.eps) ** (-0.5 / self.power) + scales = self.scale * (x_sq ** self.power + self.log_eps.exp()) ** (-0.5 / self.power) return (x * scales) - eps = limit_param_value( - self.eps, min=0.5, max=4.0, training=self.training) + log_eps = limit_param_value( + self.log_eps, min=-3.0, max=3.0, training=self.training) power = limit_param_value( self.power, min=0.25, max=2.0, training=self.training) @@ -484,10 +484,10 @@ def forward(self, x: Tensor) -> Tensor: if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, power={power.item()}, eps**(1/power)={(eps ** (1/power))}, scale={scale.item()}, (eps**(0.5/power))/x_rms={(eps**(0.5/power))/x_rms}") + logging.info(f"name={self.name}: x_rms={x_rms}, power={power.item()}, eps={log_eps.exp()}, eps**(1/power)={(log_eps.exp() ** (1/power))}, scale={scale.item()}, (eps**(0.5/power))/x_rms={(log_eps.exp()**(0.5/power))/x_rms}") return BiasNormFunction.apply( - x, eps, power, scale, self.channel_dim, + x, log_eps, power, scale, self.channel_dim, ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d0aeafbd6a..0c701c4ab7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -572,23 +572,22 @@ def forward( diff = (x1 - ans) / (t - t**2) diff_scale = (diff ** 2).mean(dim=2, keepdim=True).sqrt() - ans_scale = (ans ** 2).mean(dim=2, keepdim=True).sqrt() + ans_scale_sq = (ans ** 2).mean(dim=2, keepdim=True) + ans_scale = ans_scale_sq.sqrt() + scale = float(self.randomize_scale) with torch.cuda.amp.autocast(enabled=False): # float(self.randomize_scale) * diff_scale is the main term that penalizes deviations from - # linear "flow". 0.01 is a constant term that will motivate the network to increase the - # dynamic range of the activations. 0.1 * (ans_scale - 1).relu() is a term that - # will start adding a penalty if any frames have rms value greater than 1, so in combination - # with the 0.01 constant term this should keep the activations just under 1; this will - # also help to discourage 'outlier' frames that have larger-than-normal norm. - noise_scale = float(self.randomize_scale) * diff_scale + 0.01 + 0.1 * (ans_scale - 1).relu() + # linear "flow". + # 0.005 ( 1 + ans_scale_sq) is supposed to encourage the rms of embedding vectors to be + # about 1. + noise_scale = float(self.randomize_scale) * diff_scale + 0.005 * (1 + ans_scale_sq) rand = torch.randn_like(src) * noise_scale if random.random() < 0.01 or __name__ == '__main__': # logging output - ans_scale = (ans ** 2).mean().sqrt() - vt_scale = ((ans - src) ** 2).mean().sqrt() - logging.info(f"name={self.name}: ans_scale={ans_scale}, vt_scale={vt_scale}, diff-scale={diff_sqscale.mean().sqrt()}") + vt_scale = ((ans - src) ** 2).mean(dim=2, keepdim=True).sqrt().mean() + logging.info(f"name={self.name}: ans_scale={ans_scale.mean()}, vt_scale={vt_scale}, diff-scale={diff_scale.mean()}") return self.norm(ans + rand) From 816089df0a0871160108ae2d8ac509553f663b77 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Jan 2025 19:19:30 +0800 Subject: [PATCH 0085/1191] Print noise_scale --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0c701c4ab7..6c7a1a49ce 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -587,7 +587,7 @@ def forward( if random.random() < 0.01 or __name__ == '__main__': # logging output vt_scale = ((ans - src) ** 2).mean(dim=2, keepdim=True).sqrt().mean() - logging.info(f"name={self.name}: ans_scale={ans_scale.mean()}, vt_scale={vt_scale}, diff-scale={diff_scale.mean()}") + logging.info(f"name={self.name}: ans_scale={ans_scale.mean()}, vt_scale={vt_scale}, diff-scale={diff_scale.mean()}, noise-scale={noise_scale.mean()}") return self.norm(ans + rand) From 5cd32124aa60a67b2340842fdd7d297d1154e6f2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Jan 2025 21:20:21 +0800 Subject: [PATCH 0086/1191] cosmetic fix --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6c7a1a49ce..15d30f5b77 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -581,7 +581,7 @@ def forward( # linear "flow". # 0.005 ( 1 + ans_scale_sq) is supposed to encourage the rms of embedding vectors to be # about 1. - noise_scale = float(self.randomize_scale) * diff_scale + 0.005 * (1 + ans_scale_sq) + noise_scale = scale * diff_scale + 0.005 * (1 + ans_scale_sq) rand = torch.randn_like(src) * noise_scale if random.random() < 0.01 or __name__ == '__main__': From d203f6d1b81c67100a28bcc3bf6ce8b6a8e09853 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 28 Jan 2025 13:47:51 +0800 Subject: [PATCH 0087/1191] Add Balancer, as in 80conv --- egs/librispeech/ASR/zipformer/zipformer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 15d30f5b77..966d00cb49 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -523,6 +523,14 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) + self.balancer = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.5, + max_positive=0.7, + min_abs=0.5, + max_abs=10.0, + ) self.norm = BiasNorm(embed_dim) @@ -638,6 +646,8 @@ def forward_internal( src = self.bypass(src_orig, src) + src = self.balancer(src) + return src def streaming_forward( From 51b3aeed5034fcca5639c708fd52b0fc6514c74e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 29 Jan 2025 12:06:59 +0800 Subject: [PATCH 0088/1191] Add out_balancer to subsampling.py --- egs/librispeech/ASR/zipformer/subsampling.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index b2f769d3f6..10266ff351 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -271,6 +271,19 @@ def __init__( self.layer3_channels = layer3_channels self.out = nn.Linear(self.out_width * layer3_channels, out_channels) + + # we don't want very large values here as it could lead to nan's in the forward pass in fp16; + # this happened in some experiments. that's the reason why this Balancer was inroduced, i.e + # the max_abs is the most important limit. + self.out_balancer = Balancer( + out_channels, + channel_dim=-1, + min_positive=0.2, + max_positive=0.8, + min_abs=0.5, + max_abs=5.0, + ) + # use a larger than normal grad_scale on this whitening module; there is # only one such module, so there is not a concern about adding together # many copies of this extra gradient term. @@ -316,6 +329,7 @@ def forward( # now x: (N, (T-7)//2, out_width * layer3_channels)) x = self.out(x) + x = self.out_balancer(x) # Now x is of shape (N, (T-7)//2, odim) x = self.out_whiten(x) x = self.out_norm(x) From aaf8c89a221513153704c74864646944c088382a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Jan 2025 13:15:58 +0800 Subject: [PATCH 0089/1191] Balance scale of features differently, with new ScaleBalancer, in zipformer layers and frontend. --- egs/librispeech/ASR/zipformer/scaling.py | 33 ++++++++++++++++++++ egs/librispeech/ASR/zipformer/subsampling.py | 11 ++----- egs/librispeech/ASR/zipformer/zipformer.py | 15 ++------- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8dfe7bb95f..82f24a8abc 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1006,6 +1006,39 @@ def _approx_inverse_erf(x): return _no_op(x) + +class ScaleBalancer(torch.nn.Module): + """ + Tries to make the rms value of the features around 1, using + strategically added noise. This is not per dimension, but globally. + Assumes channel dim is -1 and the input shape has >1 dimension. + """ + + def __init__(self): + super().__init__() + self.noise_scale = 0.1 + + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return _no_op(x) + + x_shape = list(x.shape) + x_shape[-1] = 1 + + # we estimate the rms value of x from about 1 in 20 embedding vectors, or at most about 500 + # embedding vectors. This is to prevent the grads propagated this way from being so small + # that when added to the main gradient term they make no difference, in fp16. + r = torch.rand(*x_shape, device=x.device) + prob = 0.05 + mask = (r < prob).to(x.dtype) + x_sq = (x ** 2).sum(dim=-1, keepdim=True) + x_sq_mean = (x_sq * mask).mean() / mask.mean().clamp_(min=1.0) + + noise = ((self.noise_scale * (1 + x_sq_mean)) * mask) * torch.randn_like(x) + return x + noise + + def penalize_abs_values_gt( x: Tensor, limit: float, penalty: float, name: str = None ) -> Tensor: diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 10266ff351..58b9c9e8b1 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -22,6 +22,7 @@ import torch from scaling import ( Balancer, + ScaleBalancer, BiasNorm, Dropout3, FloatLike, @@ -84,14 +85,8 @@ def __init__( initial_scale=0.01, ) - self.out_balancer = Balancer( - channels, - channel_dim=1, - min_positive=0.4, - max_positive=0.6, - min_abs=1.0, - max_abs=6.0, - ) + self.out_balancer = ScaleBalancer() + self.out_whiten = Whiten( num_groups=1, whitening_limit=5.0, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 966d00cb49..5a75c91b82 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -31,6 +31,7 @@ ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, Balancer, + ScaleBalancer, BiasNorm, ChunkCausalDepthwiseConv1d, Dropout2, @@ -523,19 +524,11 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) - self.balancer = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.5, - max_positive=0.7, - min_abs=0.5, - max_abs=10.0, - ) + self.balancer = ScaleBalancer() self.norm = BiasNorm(embed_dim) - def forward( self, src: Tensor, @@ -587,9 +580,7 @@ def forward( with torch.cuda.amp.autocast(enabled=False): # float(self.randomize_scale) * diff_scale is the main term that penalizes deviations from # linear "flow". - # 0.005 ( 1 + ans_scale_sq) is supposed to encourage the rms of embedding vectors to be - # about 1. - noise_scale = scale * diff_scale + 0.005 * (1 + ans_scale_sq) + noise_scale = scale * diff_scale rand = torch.randn_like(src) * noise_scale if random.random() < 0.01 or __name__ == '__main__': From 012d6b1a99e5f5eb991bdc122a9e7237646fb7c9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Jan 2025 13:49:47 +0800 Subject: [PATCH 0090/1191] Initialize proj to random non-orthogonal matrix --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- egs/librispeech/ASR/zipformer/zipformer.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8dfe7bb95f..cba0fb8661 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -568,9 +568,9 @@ def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): self.min_product_scale = 0.01 self.name = None # will be set from training loop. for printing penalty. - # by default, initialize to the identity. with torch.no_grad(): - self.weight[:] = torch.eye(num_channels) + # this is not orthogonal but should quickly become so. + self.weight[:] = torch.randn(num_channels, num_channels) * (num_channels ** -0.5) def forward(self, x: Tensor): ans = nn.functional.linear(x, self.weight, self.bias) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 966d00cb49..657a7d4035 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -28,6 +28,7 @@ from scaling import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. OrthogonalLinearSpecial, + OrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, Balancer, @@ -1008,7 +1009,7 @@ def __init__( ): super().__init__() assert proj_dim <= channels * 2 - self.proj = OrthogonalLinearSpecial(proj_dim, penalty_scale=penalty_scale) + self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -1054,9 +1055,7 @@ class InvertibleUpsample(torch.nn.Module): def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0): super().__init__() assert proj_dim <= channels - self.proj = OrthogonalLinearSpecial(proj_dim, - penalty_scale=penalty_scale, - transpose=True) + self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) def forward(self, src: Tensor) -> Tensor: """ From 712e71837ce57fd1c3ca535806b0befe2b0f3645 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Jan 2025 14:07:40 +0800 Subject: [PATCH 0091/1191] Reduce randomize_proportion from .25 to .1 to save memory --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 657a7d4035..2ed3b781a8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -832,7 +832,7 @@ def forward( src, bypass = src[..., :layer_dim], src[..., layer_dim:] - randomize_proportion = 0.25 + randomize_proportion = 0.1 L = len(self.layers) # int(...) rounds down. we'll only randomize >= 2 layers if L >= 8. num_randomize = max(1, int(0.5 + L * randomize_proportion)) From df850fd72769c9b78776ce218bdb2e03bee45721 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Jan 2025 14:24:35 +0800 Subject: [PATCH 0092/1191] Fix mask mean, previously scaled up noise too much --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 82f24a8abc..653c49dacf 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1033,7 +1033,7 @@ def forward(self, x: Tensor) -> Tensor: prob = 0.05 mask = (r < prob).to(x.dtype) x_sq = (x ** 2).sum(dim=-1, keepdim=True) - x_sq_mean = (x_sq * mask).mean() / mask.mean().clamp_(min=1.0) + x_sq_mean = (x_sq * mask).mean() / mask.mean().clamp_(min=0.5*prob) noise = ((self.noise_scale * (1 + x_sq_mean)) * mask) * torch.randn_like(x) return x + noise From 28330712dd5a5d3c1a06b480a52fe90379edc57c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Jan 2025 16:17:48 +0800 Subject: [PATCH 0093/1191] Fix error in formula --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 653c49dacf..5d19cda0ba 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1032,7 +1032,7 @@ def forward(self, x: Tensor) -> Tensor: r = torch.rand(*x_shape, device=x.device) prob = 0.05 mask = (r < prob).to(x.dtype) - x_sq = (x ** 2).sum(dim=-1, keepdim=True) + x_sq = (x ** 2).mean(dim=-1, keepdim=True) x_sq_mean = (x_sq * mask).mean() / mask.mean().clamp_(min=0.5*prob) noise = ((self.noise_scale * (1 + x_sq_mean)) * mask) * torch.randn_like(x) From ae9f828a91c809d47eec8258a9d616546be34521 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Jan 2025 21:25:17 +0800 Subject: [PATCH 0094/1191] fix return params; decrease prob to 0.01 and increase noise_scale to 0.2, in ScaleBalancer. --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- egs/librispeech/ASR/zipformer/train.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5d19cda0ba..0961fe7002 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1016,7 +1016,7 @@ class ScaleBalancer(torch.nn.Module): def __init__(self): super().__init__() - self.noise_scale = 0.1 + self.noise_scale = 0.2 def forward(self, x: Tensor) -> Tensor: @@ -1030,10 +1030,10 @@ def forward(self, x: Tensor) -> Tensor: # embedding vectors. This is to prevent the grads propagated this way from being so small # that when added to the main gradient term they make no difference, in fp16. r = torch.rand(*x_shape, device=x.device) - prob = 0.05 + prob = 0.01 mask = (r < prob).to(x.dtype) x_sq = (x ** 2).mean(dim=-1, keepdim=True) - x_sq_mean = (x_sq * mask).mean() / mask.mean().clamp_(min=0.5*prob) + x_sq_mean = (x_sq * mask).mean() / prob noise = ((self.noise_scale * (1 + x_sq_mean)) * mask) * torch.randn_like(x) return x + noise diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 7f456e6816..f96c051a15 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -811,7 +811,7 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - return saved_params + return params def save_checkpoint( From 3f97fc98841e16b1fa648e64d619060428a9d9da Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Jan 2025 22:23:01 +0800 Subject: [PATCH 0095/1191] do the change I originally intended in subsampling.py where the out_balancer of Conv2dSubsmpling is the ScaleBalancer. --- egs/librispeech/ASR/zipformer/subsampling.py | 21 +++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 58b9c9e8b1..0d6c2d813d 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -85,7 +85,14 @@ def __init__( initial_scale=0.01, ) - self.out_balancer = ScaleBalancer() + self.out_balancer = Balancer( + channels, + channel_dim=1, + min_positive=0.4, + max_positive=0.6, + min_abs=1.0, + max_abs=6.0, + ) self.out_whiten = Whiten( num_groups=1, @@ -267,17 +274,7 @@ def __init__( self.out = nn.Linear(self.out_width * layer3_channels, out_channels) - # we don't want very large values here as it could lead to nan's in the forward pass in fp16; - # this happened in some experiments. that's the reason why this Balancer was inroduced, i.e - # the max_abs is the most important limit. - self.out_balancer = Balancer( - out_channels, - channel_dim=-1, - min_positive=0.2, - max_positive=0.8, - min_abs=0.5, - max_abs=5.0, - ) + self.out_balancer = ScaleBalancer() # use a larger than normal grad_scale on this whitening module; there is # only one such module, so there is not a concern about adding together From aabc945a983a515a205d348ae131340ed11ff262 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 31 Jan 2025 11:02:20 +0800 Subject: [PATCH 0096/1191] Revert change about saved_params vs params --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f96c051a15..7f456e6816 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -811,7 +811,7 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - return params + return saved_params def save_checkpoint( From 46e67b9f76fc604efbb7e35930d47ad0782b3642 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 2 Feb 2025 12:31:53 +0800 Subject: [PATCH 0097/1191] Remove randomization --- egs/librispeech/ASR/zipformer/zipformer.py | 75 +--------------------- 1 file changed, 3 insertions(+), 72 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 320778cf0b..9080e99da8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -537,69 +537,6 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - randomize: bool = False, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - randomize: if true use a form of randomization/dropout that encourages invertibility. - - Returns: - A tensor which has the same shape as src - """ - ans = self.forward_internal(src, pos_emb, chunk_size, - attn_mask, src_key_padding_mask) - if torch.jit.is_scripting() or torch.jit.is_tracing() or not (self.training and randomize): - return self.norm(ans) - - # we view the input 'src' as x0 and the answer 'ans' as x1, like in a flow-matching - # situation, and we compute an alternative version of x1 (called "x1" in the code) - # that is computed as two steps. We then amplify the difference between "ans" and - # that alternative version of x1, and multiply it by random noise. - - (seq_len, batch_size, emb_dim) = src.shape - t = torch.empty(batch_size, 1, device=src.device).uniform_(0.1, 0.9) - xt = src + (ans - src) * t - # ans_t is the network evaluated at t. it's interpreted as xt + vt. - ans_t = self.forward_internal(xt, pos_emb, chunk_size, attn_mask, src_key_padding_mask) - x1 = xt + (ans_t - xt) * (1. - t) - diff = (x1 - ans) / (t - t**2) - - diff_scale = (diff ** 2).mean(dim=2, keepdim=True).sqrt() - ans_scale_sq = (ans ** 2).mean(dim=2, keepdim=True) - ans_scale = ans_scale_sq.sqrt() - - scale = float(self.randomize_scale) - with torch.cuda.amp.autocast(enabled=False): - # float(self.randomize_scale) * diff_scale is the main term that penalizes deviations from - # linear "flow". - noise_scale = scale * diff_scale - - rand = torch.randn_like(src) * noise_scale - if random.random() < 0.01 or __name__ == '__main__': - # logging output - vt_scale = ((ans - src) ** 2).mean(dim=2, keepdim=True).sqrt().mean() - logging.info(f"name={self.name}: ans_scale={ans_scale.mean()}, vt_scale={vt_scale}, diff-scale={diff_scale.mean()}, noise-scale={noise_scale.mean()}") - - return self.norm(ans + rand) - - - - def forward_internal( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """ Pass the input through the encoder layer. @@ -640,7 +577,7 @@ def forward_internal( src = self.balancer(src) - return src + return self.norm(src) def streaming_forward( self, @@ -750,6 +687,8 @@ def streaming_forward( src = self.bypass(src_orig, src) + src = self.norm(src) + return ( src, cached_key, @@ -822,13 +761,6 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] - - randomize_proportion = 0.1 - L = len(self.layers) - # int(...) rounds down. we'll only randomize >= 2 layers if L >= 8. - num_randomize = max(1, int(0.5 + L * randomize_proportion)) - randomize_layer = [ True ] * num_randomize + [ False ] * (L - num_randomize) - random.shuffle(randomize_layer) for i, mod in enumerate(self.layers): src = mod( src, @@ -836,7 +768,6 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, - randomize=randomize_layer[i], ) # randomize_factor can be viewed as a simple version of an # importance-sampling factor. From 0f00728cad8ee69a53f12103a6871b1e1bbc28ab Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Feb 2025 12:29:10 +0800 Subject: [PATCH 0098/1191] Add noise to bypass of zipformer encoder. --- egs/librispeech/ASR/zipformer/zipformer.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9080e99da8..7ca3c3a7fd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -721,6 +721,7 @@ def __init__( num_layers: int, pos_dim: int, dropout: float, + bypass_noise: FloatLike = ScheduledFloat((0.0, 0.5), (10000.0, 0.05)), ) -> None: super().__init__() self.encoder_pos = CompactRelPositionalEncoding( @@ -731,6 +732,7 @@ def __init__( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers + self.bypass_noise = copy.deepcopy(bypass_noise) def forward( self, @@ -761,6 +763,9 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] + if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): + bypass = self._add_noise_to_bypass(bypass) + for i, mod in enumerate(self.layers): src = mod( src, @@ -777,6 +782,20 @@ def forward( return src + def _add_noise_to_bypass(self, x: Tensor): + bypass_scale = float(self.bypass_noise) + # a simpler way to set the noise scale would be to use + # bypass_scale * (x ** 2).mean().sqrt(). Using + # 0.5 * ((x ** 2).mean() + 1.0) instead gives the same answer when the rms + # is 1.0, and a larger answer elsewhere, so it encourages the rms of + # x to be about 1.0. Using .mean(dim=-1, keepdim=True) instead of .mean(), i.e. per-frame + # magnitude, helps to keep the gradients more concentrated which, in fp16 + # training, should reduce certain biases caused by roundoff which otherwise + # tend to lead the embeddings to get smaller in scale. + noise_scale = (0.5 * bypass_scale) * ((x ** 2).mean(dim=-1, keepdim=True) + 1.0) + return x + torch.randn_like(x) * noise_scale + + def streaming_forward( self, src: Tensor, From 8865cadb9ebad3ca6adcfbeb1c02d5a30225ed12 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Feb 2025 15:24:23 +0800 Subject: [PATCH 0099/1191] Different bypass_noise schedule that starts and ends at zero. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7ca3c3a7fd..ad0866aa5b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -721,7 +721,7 @@ def __init__( num_layers: int, pos_dim: int, dropout: float, - bypass_noise: FloatLike = ScheduledFloat((0.0, 0.5), (10000.0, 0.05)), + bypass_noise: FloatLike = ScheduledFloat((0.0, 0.0), (4000.0, 0.2), (8000.0, 0.0)), ) -> None: super().__init__() self.encoder_pos = CompactRelPositionalEncoding( From ffc78421e7f9d4045d94a62f39c9c262b2b06d10 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Feb 2025 18:58:01 +0800 Subject: [PATCH 0100/1191] implement random rotation of diff-between-pairs-of-frames early in training, to try to encourage nice projections --- egs/librispeech/ASR/zipformer/zipformer.py | 27 ++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ad0866aa5b..78f65c47ca 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -993,10 +993,13 @@ class InvertibleUpsample(torch.nn.Module): scale is smaller you may want this to be smaller. """ - def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0): + def __init__(self, channels: int, proj_dim: int, + penalty_scale: float = 1000.0, + rotate_prob: FloatLike = ScheduledFloat((0.0, 0.0), (4000.0, 0.2), (8000.0, 0.0))): super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) + self.rotate_prob = copy.deepcopy(rotate_prob) def forward(self, src: Tensor) -> Tensor: """ @@ -1013,11 +1016,31 @@ def forward(self, src: Tensor) -> Tensor: else: src = self.proj(src) - src = torch.stack((src[..., 0::2], src[..., 1::2]), + + a, b = src[..., 0::2], src[..., 1::2] + + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + a, b = self._random_rotate(a, b) + + src = torch.stack((a, b), dim=1) # (seq_len, 2, batch_size, in_channels // 2) src = src.reshape(seq_len * 2, batch_size, in_channels // 2) return src + def _random_rotate(self, a: Tensor, b: Tensor): + rotate_prob = float(self.rotate_prob) + if rotate_prob == 0.0: + return a, b + + mean = 0.5 * (a + b) + diff = 0.5 * (b - a) + + diff_scale = torch.empty_like(a[..., :1]).uniform_(-1.0, 1.0) + diff = diff * diff_scale + return mean - diff, mean + diff + + class CompactRelPositionalEncoding(torch.nn.Module): """ From 1f37d88d8cd85b3203f44fa18b590183795a5e72 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Feb 2025 19:33:09 +0800 Subject: [PATCH 0101/1191] Bug fix, do what I intended, but with larger rotate_prob. --- egs/librispeech/ASR/zipformer/zipformer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 78f65c47ca..fd6e8e39ba 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -995,7 +995,7 @@ class InvertibleUpsample(torch.nn.Module): """ def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0, - rotate_prob: FloatLike = ScheduledFloat((0.0, 0.0), (4000.0, 0.2), (8000.0, 0.0))): + rotate_prob: FloatLike = ScheduledFloat((0.0, 0.0), (4000.0, 0.5), (8000.0, 0.0))): super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) @@ -1036,7 +1036,10 @@ def _random_rotate(self, a: Tensor, b: Tensor): mean = 0.5 * (a + b) diff = 0.5 * (b - a) - diff_scale = torch.empty_like(a[..., :1]).uniform_(-1.0, 1.0) + x = a[..., :1] + diff_scale = torch.where(torch.rand_like(x) < rotate_prob, + torch.empty_like(x).uniform_(-1.0, 1.0), + torch.ones_like(x)) diff = diff * diff_scale return mean - diff, mean + diff From 9697080b30a769e30661b20e8de59f0228052e84 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Feb 2025 21:22:06 +0800 Subject: [PATCH 0102/1191] make rotate_prob start out at 1.0. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fd6e8e39ba..b4dd1be78b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -995,7 +995,7 @@ class InvertibleUpsample(torch.nn.Module): """ def __init__(self, channels: int, proj_dim: int, penalty_scale: float = 1000.0, - rotate_prob: FloatLike = ScheduledFloat((0.0, 0.0), (4000.0, 0.5), (8000.0, 0.0))): + rotate_prob: FloatLike = ScheduledFloat((0.0, 1.0), (8000.0, 0.0))): super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) From 68599f3e51e177d831d5e72cfaac3ccf621fb5e3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Feb 2025 23:34:17 +0800 Subject: [PATCH 0103/1191] Bug fix in streaming_forward, there will be more bugs --- egs/librispeech/ASR/zipformer/zipformer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b4dd1be78b..61bd50186f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -92,7 +92,6 @@ class Zipformer2(EncoderInterface): chunks. Must not be less than cnn_module_kernel (after factoring in rounding and downsampling); an error will be thrown if this is violated. """ - def __init__( self, output_downsampling_factor: int = 2, @@ -202,7 +201,6 @@ def set_downsample_factor(cur_downsample, ds): self.encoders = nn.ModuleList(encoders) - def get_chunk_info(self) -> Tuple[int, int]: """ Returns chunk_size and left_context_chunks. @@ -391,7 +389,7 @@ def streaming_forward( layer_offset += num_layers new_states += new_layer_states - x = x[..., :self.encoder_dim[-1]] + x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. # class Downsample has this rounding behavior.. assert self.output_downsampling_factor == 2 From 13c7e13a7ee753c499d193bf666e527dde53565e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Feb 2025 00:17:20 +0800 Subject: [PATCH 0104/1191] Implement reconstruction loss; remove random rotation --- egs/librispeech/ASR/zipformer/model.py | 60 +++++++++++++++++++- egs/librispeech/ASR/zipformer/subsampling.py | 4 +- egs/librispeech/ASR/zipformer/train.py | 11 +++- egs/librispeech/ASR/zipformer/zipformer.py | 29 +--------- 4 files changed, 72 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index c7dbe1e0ad..de78668da3 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -122,6 +122,10 @@ def __init__( else: assert attention_decoder is None + self.reconstruction_proj = torch.nn.Linear( + encoder_dim, 2 * encoder_embed.in_channels) + self.reconstruction_loss = torch.nn.SmoothL1Loss(reduction='none', beta=1.0) + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -478,4 +482,58 @@ def forward( else: attention_decoder_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss + reconstruction_loss = self.forward_reconstruction_loss(x, encoder_out, + encoder_out_lens, + use_cr_ctc) + if use_cr_ctc: + reconstruction_loss = reconstruction_loss * 0.5 + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss + + + def forward_reconstruction_loss(self, + log_mels: Tensor, + encoder_out: Tensor, + encoder_out_lens: Tensor, + use_cr_ctc: bool): + """ + Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If + use_cr_ctc then we swap the first and second halves of the batch. + + Args: + log_mels: log-mel features of shape (batch_size, T, num_mels) + encoder_out: embeddings of shape (T_embed, batch_size, encoder_dim) + """ + if use_cr_ctc: + batch_size = log_mels.shape[0] + log_mels = torch.roll(log_mels, N // 2, dims=0) + num_mels = log_mels.shape[2] + + pred_mels = self.reconstruction_proj(encoder_out) # (T_embed, batch_size, 2 * num_mels) + T_embed = pred_mels.shape[0] + pred_mels = pred_mels.reshape(T_embed, batch_size, 2, num_mels) + pred_mels = pred_mels.permute(1, 0, 2, 3).reshape(batch_size, T_embed * 2, num_mels) + + excess_frames = log_mels.shape[1] - pred_mels.shape[1] + assert 0 < excess_frames < 10 # should be around 7 or 8I believe. + if random.random() < 0.01: + logging.info("excess_frames = ", excess_frames) # TODO: remove this line + + T = pred_mels.shape[1] + offset = excess_frames // 2 + pred_mels = pred_mels[:, offset:offset+T] + + + lens = encoder_out_lens * 2 + pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions + assert pad_mask.shape == (batch_size, T) + pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position + # padd_mask: (batch_size, T, 1) + + + # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly + # helps to down-weight the effect of very silent silences. + loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, + reduction='none', beta=1.0) + loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. + return loss diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 0d6c2d813d..946fa538ac 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -228,6 +228,7 @@ def __init__( bottleneck dimension for 1d squeeze-excite """ assert in_channels >= 7 + self.in_channels = in_channels super().__init__() # The ScaleGrad module is there to prevent the gradients @@ -308,9 +309,6 @@ def forward( """ # On entry, x is (N, T, idim) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) - # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite - # gradients. x = self.conv(x) x = self.convnext(x) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 7f456e6816..3f6485282b 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -455,6 +455,13 @@ def get_parser(): help="Scale for consistency-regularization loss.", ) + parser.add_argument( + "--reconstruction-loss-scale", + type=float, + default=0.01, + help="Scale for log-mel reconstruction loss.", + ) + parser.add_argument( "--time-mask-ratio", type=float, @@ -922,7 +929,7 @@ def compute_loss( supervision_segments = None with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -959,6 +966,8 @@ def compute_loss( if use_cr_ctc: loss += params.cr_loss_scale * cr_loss + loss += params.reonstruction_loss_scale * reconstruction_loss + if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 61bd50186f..f7b316b073 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -992,12 +992,10 @@ class InvertibleUpsample(torch.nn.Module): """ def __init__(self, channels: int, proj_dim: int, - penalty_scale: float = 1000.0, - rotate_prob: FloatLike = ScheduledFloat((0.0, 1.0), (8000.0, 0.0))): + penalty_scale: float = 1000.0): super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) - self.rotate_prob = copy.deepcopy(rotate_prob) def forward(self, src: Tensor) -> Tensor: """ @@ -1014,34 +1012,11 @@ def forward(self, src: Tensor) -> Tensor: else: src = self.proj(src) - - a, b = src[..., 0::2], src[..., 1::2] - - - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - a, b = self._random_rotate(a, b) - - src = torch.stack((a, b), + src = torch.stack((src[..., 0::2], src[..., 1::2]), dim=1) # (seq_len, 2, batch_size, in_channels // 2) src = src.reshape(seq_len * 2, batch_size, in_channels // 2) return src - def _random_rotate(self, a: Tensor, b: Tensor): - rotate_prob = float(self.rotate_prob) - if rotate_prob == 0.0: - return a, b - - mean = 0.5 * (a + b) - diff = 0.5 * (b - a) - - x = a[..., :1] - diff_scale = torch.where(torch.rand_like(x) < rotate_prob, - torch.empty_like(x).uniform_(-1.0, 1.0), - torch.ones_like(x)) - diff = diff * diff_scale - return mean - diff, mean + diff - - class CompactRelPositionalEncoding(torch.nn.Module): """ From f829b75866b5451c0d13ec23b20a2291f060f308 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Feb 2025 00:18:40 +0800 Subject: [PATCH 0105/1191] add an import --- egs/librispeech/ASR/zipformer/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index de78668da3..7986ce1ee5 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -21,6 +21,7 @@ import k2 import torch import torch.nn as nn +from torch import Tensor from encoder_interface import EncoderInterface from lhotse.dataset import SpecAugment from scaling import ScaledLinear From e96785e08bbb9cf61633cfa9bb8115933e02a887 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Feb 2025 00:50:34 +0800 Subject: [PATCH 0106/1191] Bug fixes --- egs/librispeech/ASR/zipformer/model.py | 24 ++++++++++-------------- egs/librispeech/ASR/zipformer/train.py | 3 ++- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 7986ce1ee5..3e17b9ebb9 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -124,7 +124,7 @@ def __init__( assert attention_decoder is None self.reconstruction_proj = torch.nn.Linear( - encoder_dim, 2 * encoder_embed.in_channels) + encoder_dim, 4 * encoder_embed.in_channels) self.reconstruction_loss = torch.nn.SmoothL1Loss(reduction='none', beta=1.0) def forward_encoder( @@ -503,29 +503,25 @@ def forward_reconstruction_loss(self, Args: log_mels: log-mel features of shape (batch_size, T, num_mels) - encoder_out: embeddings of shape (T_embed, batch_size, encoder_dim) + encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) """ if use_cr_ctc: batch_size = log_mels.shape[0] - log_mels = torch.roll(log_mels, N // 2, dims=0) + log_mels = torch.roll(log_mels, batch_size // 2, dims=0) num_mels = log_mels.shape[2] - pred_mels = self.reconstruction_proj(encoder_out) # (T_embed, batch_size, 2 * num_mels) - T_embed = pred_mels.shape[0] - pred_mels = pred_mels.reshape(T_embed, batch_size, 2, num_mels) - pred_mels = pred_mels.permute(1, 0, 2, 3).reshape(batch_size, T_embed * 2, num_mels) + pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) + T_embed = pred_mels.shape[1] + pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) excess_frames = log_mels.shape[1] - pred_mels.shape[1] - assert 0 < excess_frames < 10 # should be around 7 or 8I believe. - if random.random() < 0.01: - logging.info("excess_frames = ", excess_frames) # TODO: remove this line + assert 4 < excess_frames < 10 # should be around 7 or 8 I believe. T = pred_mels.shape[1] - offset = excess_frames // 2 - pred_mels = pred_mels[:, offset:offset+T] + offset = 3 # i found excess_frames = 5 one time. + log_mels = log_mels[:, offset:offset+T] - - lens = encoder_out_lens * 2 + lens = encoder_out_lens * 4 pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions assert pad_mask.shape == (batch_size, T) pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3f6485282b..f4344d8ccc 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -966,7 +966,7 @@ def compute_loss( if use_cr_ctc: loss += params.cr_loss_scale * cr_loss - loss += params.reonstruction_loss_scale * reconstruction_loss + loss += params.reconstruction_loss_scale * reconstruction_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -987,6 +987,7 @@ def compute_loss( info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.use_cr_ctc: info["cr_loss"] = cr_loss.detach().cpu().item() + info["recon_loss"] = reconstruction_loss if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() From 212f7c4410ff28f453b18853c4a13841329d4e18 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Feb 2025 13:21:11 +0800 Subject: [PATCH 0107/1191] Add some diagnostic printouts. --- egs/librispeech/ASR/zipformer/scaling.py | 4 +++- egs/librispeech/ASR/zipformer/zipformer.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b53651bef8..785dac3ddb 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1017,7 +1017,7 @@ class ScaleBalancer(torch.nn.Module): def __init__(self): super().__init__() self.noise_scale = 0.2 - + self.name = None def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: @@ -1036,6 +1036,8 @@ def forward(self, x: Tensor) -> Tensor: x_sq_mean = (x_sq * mask).mean() / prob noise = ((self.noise_scale * (1 + x_sq_mean)) * mask) * torch.randn_like(x) + if random.random() < 0.001: + logging.info(f"name={self.name}, x_rms={(x**2).mean().sqrt().item()}, noise_rms={self.noise_scale*(1+(x**2).mean()).item()}") return x + noise diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f7b316b073..4ea2b249cd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -725,7 +725,7 @@ def __init__( self.encoder_pos = CompactRelPositionalEncoding( pos_dim, dropout_rate=0.0, length_factor=1.0 ) - + self.name = None self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) @@ -791,6 +791,11 @@ def _add_noise_to_bypass(self, x: Tensor): # training, should reduce certain biases caused by roundoff which otherwise # tend to lead the embeddings to get smaller in scale. noise_scale = (0.5 * bypass_scale) * ((x ** 2).mean(dim=-1, keepdim=True) + 1.0) + + if random.random() < 0.001: + logging.info(f"name={self.name}, x_rms={(x**2).mean().sqrt().item()}, bypass_scale={bypass_scale}, noise_rms={noise_scale.mean()}") + + return x + torch.randn_like(x) * noise_scale From 67637a32373eeaec532d2841e0add1bbfcdfe99d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Feb 2025 13:48:53 +0800 Subject: [PATCH 0108/1191] Change how the masking is done in ScaleBalancer, mask entire sequences; reduce noise_scale by 8. --- egs/librispeech/ASR/zipformer/scaling.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 785dac3ddb..f0f7e441ba 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1016,28 +1016,28 @@ class ScaleBalancer(torch.nn.Module): def __init__(self): super().__init__() - self.noise_scale = 0.2 + self.noise_scale = 0.05 self.name = None def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return _no_op(x) - x_shape = list(x.shape) - x_shape[-1] = 1 + + # the mask is random over the batch dim. + mask_shape = [1, x.shape[1], 1] # we estimate the rms value of x from about 1 in 20 embedding vectors, or at most about 500 # embedding vectors. This is to prevent the grads propagated this way from being so small # that when added to the main gradient term they make no difference, in fp16. - r = torch.rand(*x_shape, device=x.device) + r = torch.rand(*mask_shape, device=x.device) prob = 0.01 mask = (r < prob).to(x.dtype) - x_sq = (x ** 2).mean(dim=-1, keepdim=True) - x_sq_mean = (x_sq * mask).mean() / prob + x_sq = (x ** 2).mean(dim=(0,2), keepdim=True) - noise = ((self.noise_scale * (1 + x_sq_mean)) * mask) * torch.randn_like(x) - if random.random() < 0.001: - logging.info(f"name={self.name}, x_rms={(x**2).mean().sqrt().item()}, noise_rms={self.noise_scale*(1+(x**2).mean()).item()}") + noise = ((self.noise_scale * (1 + x_sq)) * mask) * torch.randn_like(x) + if random.random() < 0.001 or True: + logging.info(f"name={self.name}, x_rms={(x**2).mean().sqrt().item()}, noise_rms={self.noise_scale*(1+x_sq.mean()).item()}") return x + noise From fb4d0a4348454288e0a3b496585ee3275de219bc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Feb 2025 13:54:10 +0800 Subject: [PATCH 0109/1191] factor of 0.5 on noise_scale --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index f0f7e441ba..557e0f5922 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1035,8 +1035,8 @@ def forward(self, x: Tensor) -> Tensor: mask = (r < prob).to(x.dtype) x_sq = (x ** 2).mean(dim=(0,2), keepdim=True) - noise = ((self.noise_scale * (1 + x_sq)) * mask) * torch.randn_like(x) - if random.random() < 0.001 or True: + noise = (((0.5 * self.noise_scale) * (1 + x_sq)) * mask) * torch.randn_like(x) + if random.random() < 0.001: logging.info(f"name={self.name}, x_rms={(x**2).mean().sqrt().item()}, noise_rms={self.noise_scale*(1+x_sq.mean()).item()}") return x + noise From 0d312c5f8324e2a5ccf6b1b677ea3cec70b7d040 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Feb 2025 10:28:36 +0800 Subject: [PATCH 0110/1191] set lr_scale=0.5 in OrthogonalLinear --- egs/librispeech/ASR/zipformer/scaling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 557e0f5922..e1830e49bb 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -567,6 +567,7 @@ def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): self.penalty_scale = copy.deepcopy(penalty_scale) self.min_product_scale = 0.01 self.name = None # will be set from training loop. for printing penalty. + self.lr_scale = 0.5 with torch.no_grad(): # this is not orthogonal but should quickly become so. From dfa60d627b9a7369c07d109705db9b1dd07e036d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Feb 2025 17:20:00 +0800 Subject: [PATCH 0111/1191] Replace out_balancer of Conv2dSubsampling with normal Balancer. --- egs/librispeech/ASR/zipformer/subsampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 946fa538ac..94b3684f8f 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -275,7 +275,8 @@ def __init__( self.out = nn.Linear(self.out_width * layer3_channels, out_channels) - self.out_balancer = ScaleBalancer() + self.out_balancer = Balancer( + out_channels, channel_dim=-1, min_abs=0.2, max_abs=1.0) # use a larger than normal grad_scale on this whitening module; there is # only one such module, so there is not a concern about adding together From 4078401411aa4d8d88544788c105e1c1706c4028 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Feb 2025 11:28:08 +0800 Subject: [PATCH 0112/1191] Increase whitening limit of self_attn modules according to ratio of value_dim to embed_dim. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4ea2b249cd..4576adabc8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1534,9 +1534,13 @@ def __init__( num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 ) + f = max(1.0, embed_dim / (num_heads * value_head_dim)) + # the whitening metric cannot be less than f because of the rank imposed + # by the bottleneck. the final whitening limit will be (2.0*3.0) times f, + # i.e. 6 times greater than the mathematical smallest value it can have. self.whiten = Whiten( num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), + whitening_limit=_whitening_schedule(f * 2.0, ratio=3.0), prob=(0.025, 0.25), grad_scale=0.01, ) From 9922308e529e2a6ce1c232072399eb4dc7a1a6cd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Feb 2025 17:45:09 +0800 Subject: [PATCH 0113/1191] Widen limits of min_positive,max_positive on balance_keys. --- egs/librispeech/ASR/zipformer/zipformer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4576adabc8..4aef992b19 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -523,7 +523,16 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) - self.balancer = ScaleBalancer() + # warm up the grad_scale slowly so it does not cause instability + # near the beginning of training. + self.balancer = Balancer( + embed_dim, + channel_dim=-1, + min_abs=0.1, + max_abs=1.0, + grad_scale=ScheduledFloat((0.0, 0.0), (10000.0, 0.005)), + ) + self.norm = BiasNorm(embed_dim) @@ -1211,8 +1220,8 @@ def __init__( self.balance_keys = Balancer( key_head_dim * num_heads, channel_dim=-1, - min_positive=0.4, - max_positive=0.6, + min_positive=0.1, + max_positive=0.9, min_abs=0.0, max_abs=100.0, prob=0.025, From d72d1f20157a02bb7051891a7f2de725ad23d89c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Feb 2025 13:16:41 +0800 Subject: [PATCH 0114/1191] Add noise_scale=1.0e-04 in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 57 +++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 949626b8e6..055884e5ee 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -240,8 +240,16 @@ def scaling_step(group, p, state, grad): def momentum_step(group, p, state, grad): + # This takes care of momentum and temporary-noise. delta = scaling_step(group, p, state, grad) beta1 = group["betas"][0] + + # see the very long comment below for how noise_factor is set, it's to ensure + # the variance of temporary noise present at any given point equals + # group["noise_scale"]. + noise_factor = group["noise_scale"] * ((1 - beta1**2) ** 0.5) / beta1 + noise = torch.randn_like(p) * noise_factor + try: stored_delta = state["delta"] except KeyError: @@ -249,10 +257,50 @@ def momentum_step(group, p, state, grad): state["delta"] = stored_delta stored_delta.mul_(beta1) stored_delta.add_(delta, alpha=(1-beta1)) + + # we need to add the noise to stored_delta below with an alpha (i.e. a scale) that will ensure + # that it eventually gets totally subtracted. At the end of this function we'll return + # stored_delta + noise (just take that as a given), + # so the total scale of the noise-that-we-add-on-this-step, that will eventually be added + # to the parameters, will be + # 1.0 [for the "+ noise" term below] + + # \alpha \sum_i=0^\infty beta1^i + # so we need 1.0 + \alpha / (1-beta1) = 0, [this is what makes the noise "temporary"]. + # which we solve to \alpha = (beta1-1). + # Now we want to know the total variance of the noise present at any given point, so we + # can set the variance to group["noise_scale"] ** 2 (noise_scale is user-specified). + # + # The total variance of noise on step t will be: + # noise_factor**2 * \sum_{k=0}^\infty var_of_noise_from_step[t-k] (eqn:1) + # .. note, the value of t doesn't matter. + # + # var_of_noise_from_step[t-k] = scale_of_noise_from_step[t-k] ** 2. + # (by scale_of_noise_from_step[t-k] we mean the scale excluding the factor of noise_factor). + # + # We can write scale_of_noise_from_step[t-k] as 1 + {a difference of infinite sums}; + # the 1 comes from the "+ noise" when we return below. + # It is: + # scale_of_noise_from_step[t-k] = 1 + \alpha / (1-beta1) - \alpha \beta1^{k+1} / (1-beta1) + # (the 1/(1-beta1) both come from infinite sums: \sum_j=0^\infty beta1**k. Interpret + # the above as: 1 + {sum-of-all-terms} - {sum-of-all-terms-past-k}. + # Substituting alpha=beta1-1: + # scale_of_noise_from_step[t-k] = 1 + (beta1-1) / (1-beta1) - (beta1-1) \beta1^{k+1} / (1-beta1) + # scale_of_noise_from_step[t-k] = beta1^{k+1} # everything else cancels. + # So: + # var_of_noise_from_step[t-k] = (beta1**2)^{k+1} + # So: + # \sum_{k=0}^\infty var_of_noise_from_step[t-k] = (beta1**2) / (1 - beta1**2) + # + # Now we want the expression of (eqn:1) to equal group["noise_scale"] ** 2, i.e. + # noise_factor**2 * (beta1**2) / (1 - beta1**2) = group["noise_scale"] ** 2 + # i.e. + # noise_factor = group["noise_scale"] * sqrt(1-beta1**2) / beta1. + stored_delta.add_(noise, alpha=(beta1-1)) + # we don't bother doing the "bias correction" part of Adam for beta1 because this is just # an edge effect that affects the first 10 or so batches; and the effect of not doing it # is just to do a slower update for the first few batches, which will help stability. - return stored_delta + return stored_delta + noise @@ -287,6 +335,11 @@ class ScaledAdam(BatchedOptimizer): as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the scaling factor on the learning rate of p_scale. eps: A general-purpose epsilon to prevent division by zero + noise_scale: The amount of temporary parameter noise that we add. This is added + in an absolute sense, not scaled by the parameter RMS; it is a mechanism + to attempt to balance the parameter rms values and avoid excessive + sensitivity to any parameters. It is "temporary" because we subtract it + again via the moving average. param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of learning the scale on the parameters (we'll constrain the rms of each non-scalar parameter tensor to be >= this value) @@ -311,6 +364,7 @@ def __init__( eps=1.0e-08, param_min_rms=1.0e-05, param_max_rms=3.0, + noise_scale=1.0e-04, scalar_max=10.0, size_update_period=4, clipping_update_period=100, @@ -324,6 +378,7 @@ def __init__( eps=eps, param_min_rms=param_min_rms, param_max_rms=param_max_rms, + noise_scale=noise_scale, scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, From 2c709928b170d9b2c67d4760081c10e13db9a776 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Feb 2025 14:31:36 +0800 Subject: [PATCH 0115/1191] Remove Balancers --- egs/librispeech/ASR/zipformer/decoder.py | 24 ------ egs/librispeech/ASR/zipformer/subsampling.py | 28 ------- egs/librispeech/ASR/zipformer/zipformer.py | 87 +------------------- 3 files changed, 1 insertion(+), 138 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 7ce44495bf..357f98a807 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -58,17 +58,6 @@ def __init__( num_embeddings=vocab_size, embedding_dim=decoder_dim, ) - # the balancers are to avoid any drift in the magnitude of the - # embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer( - decoder_dim, - channel_dim=-1, - min_positive=0.0, - max_positive=1.0, - min_abs=0.5, - max_abs=1.0, - prob=0.05, - ) self.blank_id = blank_id @@ -85,20 +74,10 @@ def __init__( groups=decoder_dim // 4, # group size == 4 bias=False, ) - self.balancer2 = Balancer( - decoder_dim, - channel_dim=-1, - min_positive=0.0, - max_positive=1.0, - min_abs=0.5, - max_abs=1.0, - prob=0.05, - ) else: # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` # when inference with torch.jit.script and context_size == 1 self.conv = nn.Identity() - self.balancer2 = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ @@ -116,8 +95,6 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: # at utterance start, we use negative ids in beam_search.py embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - embedding_out = self.balancer(embedding_out) - if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: @@ -129,6 +106,5 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) embedding_out = F.relu(embedding_out) - embedding_out = self.balancer2(embedding_out) return embedding_out diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 94b3684f8f..8d2f3c7b64 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -68,15 +68,6 @@ def __init__( in_channels=channels, out_channels=hidden_channels, kernel_size=1 ) - self.hidden_balancer = Balancer( - hidden_channels, - channel_dim=1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - self.activation = SwooshL() self.pointwise_conv2 = ScaledConv2d( in_channels=hidden_channels, @@ -85,15 +76,6 @@ def __init__( initial_scale=0.01, ) - self.out_balancer = Balancer( - channels, - channel_dim=1, - min_positive=0.4, - max_positive=0.6, - min_abs=1.0, - max_abs=6.0, - ) - self.out_whiten = Whiten( num_groups=1, whitening_limit=5.0, @@ -129,7 +111,6 @@ def forward_internal( bypass = x x = self.depthwise_conv(x) x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) @@ -137,7 +118,6 @@ def forward_internal( x = x * layer_skip_mask x = bypass + x - x = self.out_balancer(x) if x.requires_grad: x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last @@ -185,7 +165,6 @@ def streaming_forward( groups=self.depthwise_conv.groups, ) x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) @@ -245,7 +224,6 @@ def __init__( padding=(0, 1), # (time, freq) ), ScaleGrad(0.2), - Balancer(layer1_channels, channel_dim=1, max_abs=1.0), SwooshR(), nn.Conv2d( in_channels=layer1_channels, @@ -254,7 +232,6 @@ def __init__( stride=2, padding=0, ), - Balancer(layer2_channels, channel_dim=1, max_abs=4.0), SwooshR(), nn.Conv2d( in_channels=layer2_channels, @@ -262,7 +239,6 @@ def __init__( kernel_size=3, stride=(1, 2), # (time, freq) ), - Balancer(layer3_channels, channel_dim=1, max_abs=4.0), SwooshR(), ) @@ -275,9 +251,6 @@ def __init__( self.out = nn.Linear(self.out_width * layer3_channels, out_channels) - self.out_balancer = Balancer( - out_channels, channel_dim=-1, min_abs=0.2, max_abs=1.0) - # use a larger than normal grad_scale on this whitening module; there is # only one such module, so there is not a concern about adding together # many copies of this extra gradient term. @@ -320,7 +293,6 @@ def forward( # now x: (N, (T-7)//2, out_width * layer3_channels)) x = self.out(x) - x = self.out_balancer(x) # Now x is of shape (N, (T-7)//2, odim) x = self.out_whiten(x) x = self.out_norm(x) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4aef992b19..ac3088fada 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -523,16 +523,6 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) - # warm up the grad_scale slowly so it does not cause instability - # near the beginning of training. - self.balancer = Balancer( - embed_dim, - channel_dim=-1, - min_abs=0.1, - max_abs=1.0, - grad_scale=ScheduledFloat((0.0, 0.0), (10000.0, 0.005)), - ) - self.norm = BiasNorm(embed_dim) @@ -582,8 +572,6 @@ def forward( src = self.bypass(src_orig, src) - src = self.balancer(src) - return self.norm(src) def streaming_forward( @@ -1209,24 +1197,6 @@ def __init__( grad_scale=0.025, ) - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be sufficient to fix the problem. - self.balance_keys = Balancer( - key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.1, - max_positive=0.9, - min_abs=0.0, - max_abs=100.0, - prob=0.025, - ) - # linear transformation for positional encoding. self.linear_pos = ScaledLinear( pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 @@ -1277,7 +1247,7 @@ def forward( ) q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + k = self.whiten_keys(k) # does nothing in the forward pass. p = self.copy_pos_query(p) # for diagnostics only, does nothing. q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) @@ -1657,15 +1627,6 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(embed_dim, feedforward_dim) - self.hidden_balancer = Balancer( - feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( feedforward_dim, @@ -1686,7 +1647,6 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): def forward(self, x: Tensor): x = self.in_proj(x) - x = self.hidden_balancer(x) # out_proj contains SwooshL activation, then dropout, then linear. x = self.out_proj(x) x = self.out_whiten(x) @@ -1713,18 +1673,6 @@ def __init__( self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) self.tanh = nn.Tanh() self.identity1 = Identity() # for diagnostics. @@ -1770,7 +1718,6 @@ def forward( # s will go through tanh. - s = self.balancer(s) s = self.tanh(s) s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) @@ -1890,27 +1837,6 @@ def __init__( # the gradients on in_proj are a little noisy, likely to do with the # sigmoid in glu. - # after in_proj we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.balancer1 = Balancer( - bottleneck_dim, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) self.activation1 = Identity() # for diagnostics @@ -1932,15 +1858,6 @@ def __init__( ) ) - self.balancer2 = Balancer( - bottleneck_dim, - channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - self.whiten = Whiten( num_groups=1, whitening_limit=_whitening_schedule(7.5), @@ -1977,7 +1894,6 @@ def forward( x = self.in_proj(x) # (time, batch, 2*channels) x, s = x.chunk(2, dim=2) - s = self.balancer1(s) s = self.sigmoid(s) x = self.activation1(x) # identity. x = x * s @@ -2004,7 +1920,6 @@ def forward( else: x = self.depthwise_conv(x) - x = self.balancer2(x) x = x.permute(2, 0, 1) # (time, batch, channels) x = self.whiten(x) # (time, batch, channels) From 5ae446d00d49651b58a37f2e124fd3c8eb91d2f7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Feb 2025 22:29:52 +0800 Subject: [PATCH 0116/1191] Reverse the adding of noise_scale. --- egs/librispeech/ASR/zipformer/optim.py | 57 +------------------------- 1 file changed, 1 insertion(+), 56 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 055884e5ee..949626b8e6 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -240,16 +240,8 @@ def scaling_step(group, p, state, grad): def momentum_step(group, p, state, grad): - # This takes care of momentum and temporary-noise. delta = scaling_step(group, p, state, grad) beta1 = group["betas"][0] - - # see the very long comment below for how noise_factor is set, it's to ensure - # the variance of temporary noise present at any given point equals - # group["noise_scale"]. - noise_factor = group["noise_scale"] * ((1 - beta1**2) ** 0.5) / beta1 - noise = torch.randn_like(p) * noise_factor - try: stored_delta = state["delta"] except KeyError: @@ -257,50 +249,10 @@ def momentum_step(group, p, state, grad): state["delta"] = stored_delta stored_delta.mul_(beta1) stored_delta.add_(delta, alpha=(1-beta1)) - - # we need to add the noise to stored_delta below with an alpha (i.e. a scale) that will ensure - # that it eventually gets totally subtracted. At the end of this function we'll return - # stored_delta + noise (just take that as a given), - # so the total scale of the noise-that-we-add-on-this-step, that will eventually be added - # to the parameters, will be - # 1.0 [for the "+ noise" term below] + - # \alpha \sum_i=0^\infty beta1^i - # so we need 1.0 + \alpha / (1-beta1) = 0, [this is what makes the noise "temporary"]. - # which we solve to \alpha = (beta1-1). - # Now we want to know the total variance of the noise present at any given point, so we - # can set the variance to group["noise_scale"] ** 2 (noise_scale is user-specified). - # - # The total variance of noise on step t will be: - # noise_factor**2 * \sum_{k=0}^\infty var_of_noise_from_step[t-k] (eqn:1) - # .. note, the value of t doesn't matter. - # - # var_of_noise_from_step[t-k] = scale_of_noise_from_step[t-k] ** 2. - # (by scale_of_noise_from_step[t-k] we mean the scale excluding the factor of noise_factor). - # - # We can write scale_of_noise_from_step[t-k] as 1 + {a difference of infinite sums}; - # the 1 comes from the "+ noise" when we return below. - # It is: - # scale_of_noise_from_step[t-k] = 1 + \alpha / (1-beta1) - \alpha \beta1^{k+1} / (1-beta1) - # (the 1/(1-beta1) both come from infinite sums: \sum_j=0^\infty beta1**k. Interpret - # the above as: 1 + {sum-of-all-terms} - {sum-of-all-terms-past-k}. - # Substituting alpha=beta1-1: - # scale_of_noise_from_step[t-k] = 1 + (beta1-1) / (1-beta1) - (beta1-1) \beta1^{k+1} / (1-beta1) - # scale_of_noise_from_step[t-k] = beta1^{k+1} # everything else cancels. - # So: - # var_of_noise_from_step[t-k] = (beta1**2)^{k+1} - # So: - # \sum_{k=0}^\infty var_of_noise_from_step[t-k] = (beta1**2) / (1 - beta1**2) - # - # Now we want the expression of (eqn:1) to equal group["noise_scale"] ** 2, i.e. - # noise_factor**2 * (beta1**2) / (1 - beta1**2) = group["noise_scale"] ** 2 - # i.e. - # noise_factor = group["noise_scale"] * sqrt(1-beta1**2) / beta1. - stored_delta.add_(noise, alpha=(beta1-1)) - # we don't bother doing the "bias correction" part of Adam for beta1 because this is just # an edge effect that affects the first 10 or so batches; and the effect of not doing it # is just to do a slower update for the first few batches, which will help stability. - return stored_delta + noise + return stored_delta @@ -335,11 +287,6 @@ class ScaledAdam(BatchedOptimizer): as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the scaling factor on the learning rate of p_scale. eps: A general-purpose epsilon to prevent division by zero - noise_scale: The amount of temporary parameter noise that we add. This is added - in an absolute sense, not scaled by the parameter RMS; it is a mechanism - to attempt to balance the parameter rms values and avoid excessive - sensitivity to any parameters. It is "temporary" because we subtract it - again via the moving average. param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of learning the scale on the parameters (we'll constrain the rms of each non-scalar parameter tensor to be >= this value) @@ -364,7 +311,6 @@ def __init__( eps=1.0e-08, param_min_rms=1.0e-05, param_max_rms=3.0, - noise_scale=1.0e-04, scalar_max=10.0, size_update_period=4, clipping_update_period=100, @@ -378,7 +324,6 @@ def __init__( eps=eps, param_min_rms=param_min_rms, param_max_rms=param_max_rms, - noise_scale=noise_scale, scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, From 3eb36dca23e7e50763a324aadb970bd39d6f18cd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Feb 2025 22:32:51 +0800 Subject: [PATCH 0117/1191] change log_eps to eps and make min=1.0, max=2.0 --- egs/librispeech/ASR/zipformer/scaling.py | 32 ++++++++++++------------ 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e1830e49bb..3b9f28d9e4 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -365,7 +365,7 @@ def backward(ctx, x_grad, *args): class BiasNormFunction(torch.autograd.Function): # This computes: - # scales = (torch.mean(x ** 2 + log_eps.exp(), keepdim=True)) ** -0.5 * log_scale.exp() + # scales = (torch.mean(x ** 2 + eps, keepdim=True)) ** -0.5 * log_scale.exp() # return x * scales # (after unsqueezing the bias), but it does it in a memory-efficient way so that # it can just store the returned value (chances are, this will also be needed for @@ -374,7 +374,7 @@ class BiasNormFunction(torch.autograd.Function): def forward( ctx, x: Tensor, - log_eps: Tensor, + eps: Tensor, power: Tensor, scale: Tensor, channel_dim: int, @@ -384,11 +384,11 @@ def forward( ctx.channel_dim = channel_dim x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = scale * (x_sq ** power + log_eps.exp()) ** (-0.5 / power) + scales = scale * (x_sq ** power + eps) ** (-0.5 / power) ans = x * scales ctx.save_for_backward( x.detach(), - log_eps.detach(), + eps.detach(), power.detach(), scale.detach(), ) @@ -396,18 +396,18 @@ def forward( @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, log_eps, power, scale = ctx.saved_tensors + x, eps, power, scale = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): - x, power, log_eps, scale = x.to(torch.float32), power.to(torch.float32), log_eps.to(torch.float32), scale.to(torch.float32) - x, power, log_eps, scale = x.detach(), power.detach(), log_eps.detach(), scale.detach() + x, power, eps, scale = x.to(torch.float32), power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, power, eps, scale = x.detach(), power.detach(), eps.detach(), scale.detach() x.requires_grad = True - log_eps.requires_grad = True + eps.requires_grad = True power.requires_grad = True scale.requires_grad = True with torch.enable_grad(): x_sq = torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) - scales = scale * (x_sq ** power + log_eps.exp()) ** (-0.5 / power) + scales = scale * (x_sq ** power + eps) ** (-0.5 / power) ans = x * scales ans.backward(gradient=ans_grad.to(torch.float32)) @@ -416,7 +416,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(log_eps.grad), c(power.grad), c(scale.grad), None + return x.grad, c(eps.grad), c(power.grad), c(scale.grad), None class BiasNorm(torch.nn.Module): @@ -456,7 +456,7 @@ def __init__( self.num_channels = num_channels self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(2.0)) - self.log_eps = nn.Parameter(torch.tensor(0.0)) + self.eps = nn.Parameter(torch.tensor(1.0)) self.power = nn.Parameter(torch.tensor(1.0)) self.name = None @@ -470,11 +470,11 @@ def forward(self, x: Tensor) -> Tensor: if channel_dim < 0: channel_dim += x.ndim x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = self.scale * (x_sq ** self.power + self.log_eps.exp()) ** (-0.5 / self.power) + scales = self.scale * (x_sq ** self.power + self.eps) ** (-0.5 / self.power) return (x * scales) - log_eps = limit_param_value( - self.log_eps, min=-3.0, max=3.0, training=self.training) + eps = limit_param_value( + self.eps, min=1.0, max=2.0, training=self.training) power = limit_param_value( self.power, min=0.25, max=2.0, training=self.training) @@ -484,10 +484,10 @@ def forward(self, x: Tensor) -> Tensor: if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, power={power.item()}, eps={log_eps.exp()}, eps**(1/power)={(log_eps.exp() ** (1/power))}, scale={scale.item()}, (eps**(0.5/power))/x_rms={(log_eps.exp()**(0.5/power))/x_rms}") + logging.info(f"name={self.name}: x_rms={x_rms}, power={power.item()}, eps={eps.item()}, eps**(1/power)={(eps ** (1/power))}, scale={scale.item()}, (eps**(0.5/power))/x_rms={(eps**(0.5/power))/x_rms}") return BiasNormFunction.apply( - x, log_eps, power, scale, self.channel_dim, + x, eps, power, scale, self.channel_dim, ) From 4bfe2d17680a92685816655bb5e4ab1e150f2f10 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Feb 2025 22:56:33 +0800 Subject: [PATCH 0118/1191] Introduce ScaleLimiter and use it in frontend. --- egs/librispeech/ASR/zipformer/scaling.py | 63 +++++++++++++------- egs/librispeech/ASR/zipformer/subsampling.py | 6 +- egs/librispeech/ASR/zipformer/zipformer.py | 3 +- 3 files changed, 46 insertions(+), 26 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3b9f28d9e4..e8e73cb697 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1007,39 +1007,56 @@ def _approx_inverse_erf(x): return _no_op(x) +class ScaleLimiterFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, max_scale: float): + ctx.save_for_backward(x) + ctx.max_scale = max_scale + return x -class ScaleBalancer(torch.nn.Module): + @staticmethod + def backward(ctx, y_grad: Tensor): + x, = ctx.saved_tensors + # you could think of loss_scale as like a mask, it's nonzero if + # (x**2).mean() > 1.0, but it starts of small if we are close to 1.0 + # so we don't suddenly add large gradients that could be destabilizing. + eps = 0.01 + loss_scale = eps * ((x ** 2).mean() - ctx.max_scale).relu() + y_grad_rms = (y_grad ** 2).mean().sqrt() + # y_grad_rms is a scaling factor for the gradient contribution, since we + # don't know at this point the total scale of the main loss. + + # the grad of (x ** 2).mean() would be 2 * x. we absorb the factor of 2 + # into eps, which is just an arbitrary smallish value. + return y_grad + (loss_scale * y_grad_rms) * x, None + + +class ScaleLimiter(torch.nn.Module): """ - Tries to make the rms value of the features around 1, using - strategically added noise. This is not per dimension, but globally. + Tries to make the rms value of the features no greater than self.max_scale, by + adding a penalty. This is not per dimension, but globally. Assumes channel dim is -1 and the input shape has >1 dimension. """ - - def __init__(self): + def __init__(self, max_scale: FloatLike = 1.0, prob: FloatLike = 1.0): super().__init__() - self.noise_scale = 0.05 self.name = None + self.max_scale = max_scale + self.prob = prob def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return _no_op(x) - - - # the mask is random over the batch dim. - mask_shape = [1, x.shape[1], 1] - - # we estimate the rms value of x from about 1 in 20 embedding vectors, or at most about 500 - # embedding vectors. This is to prevent the grads propagated this way from being so small - # that when added to the main gradient term they make no difference, in fp16. - r = torch.rand(*mask_shape, device=x.device) - prob = 0.01 - mask = (r < prob).to(x.dtype) - x_sq = (x ** 2).mean(dim=(0,2), keepdim=True) - - noise = (((0.5 * self.noise_scale) * (1 + x_sq)) * mask) * torch.randn_like(x) - if random.random() < 0.001: - logging.info(f"name={self.name}, x_rms={(x**2).mean().sqrt().item()}, noise_rms={self.noise_scale*(1+x_sq.mean()).item()}") - return x + noise + else: + # this in effect adds a penalty to the loss function if + # (x ** 2).mean() > 1.0, the penalty will tend to reduce the value + # of (x ** 2). + if random.random() < 0.001: + logging.info(f"name={self.name}, max_scale={float(self.max_scale)}, prob={float(self.prob)}, x_rms={(x**2).mean().sqrt().item()}") + prob = float(self.prob) + if prob > 0 and random.random() < prob: + return ScaleLimiterFunction.apply(x, float(self.max_scale)) + else: + return x def penalize_abs_values_gt( diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 8d2f3c7b64..c008720ac8 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -22,7 +22,7 @@ import torch from scaling import ( Balancer, - ScaleBalancer, + ScaleLimiter, BiasNorm, Dropout3, FloatLike, @@ -251,6 +251,8 @@ def __init__( self.out = nn.Linear(self.out_width * layer3_channels, out_channels) + self.out_limiter = ScaleLimiter(max_scale=0.5) + # use a larger than normal grad_scale on this whitening module; there is # only one such module, so there is not a concern about adding together # many copies of this extra gradient term. @@ -293,6 +295,7 @@ def forward( # now x: (N, (T-7)//2, out_width * layer3_channels)) x = self.out(x) + x = self.out_limiter(x) # Now x is of shape (N, (T-7)//2, odim) x = self.out_whiten(x) x = self.out_norm(x) @@ -346,6 +349,7 @@ def streaming_forward( x = self.out(x) # Now x is of shape (N, T', odim) + x = self.out_limiter(x) x = self.out_norm(x) if torch.jit.is_scripting() or torch.jit.is_tracing(): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ac3088fada..4d0b6b68e0 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -31,8 +31,7 @@ OrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, - Balancer, - ScaleBalancer, + ScaleLimiter, BiasNorm, ChunkCausalDepthwiseConv1d, Dropout2, From 31e0c6df4c157d91c14a0e04e1c2686a650ea91a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 10:07:55 +0800 Subject: [PATCH 0119/1191] Create DeltaDropout and add dropout_ff2 = DeltaDropout(0.1, delta=0.01) in zipformer --- egs/librispeech/ASR/zipformer/scaling.py | 17 +++++++++++++++++ egs/librispeech/ASR/zipformer/zipformer.py | 6 +++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e8e73cb697..ce269ddf36 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1477,6 +1477,23 @@ def forward(self, x: Tensor) -> Tensor: return ans +# DeltaDropout does dropout but you supply a delta and it shrinks the embedding +# element toward zero by at most delta. +class DeltaDropout(nn.Module): + def __init__(self, p: FloatLike, delta: float): + super().__init__() + self.p = p + self.delta = delta + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + + rand = (torch.rand_like(x) < p) * x.abs().clamp_(max=self.delta) * -x.sgn() + return x + rand + + class SwooshLFunction(torch.autograd.Function): """ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4d0b6b68e0..933c97cf53 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -35,6 +35,7 @@ BiasNorm, ChunkCausalDepthwiseConv1d, Dropout2, + DeltaDropout, FloatLike, ScheduledFloat, Whiten, @@ -518,6 +519,9 @@ def __init__( self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + # This is supposed to encourage the scale of the embeddings to get larger if it is too small. + self.dropout_ff2 = DeltaDropout(0.1, delta=0.01) + self.conv_module = ConvolutionModule( embed_dim, cnn_module_kernel, causal=causal ) @@ -567,7 +571,7 @@ def forward( src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask ) - src = src + self.feed_forward2(src) + src = src + self.dropout_ff2(self.feed_forward2(src)) src = self.bypass(src_orig, src) From 00d3e80cf10c485de9a638478cd077e2075f3158 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 10:29:18 +0800 Subject: [PATCH 0120/1191] Add scale_limiter at end of zipformer layer, max_scale=0.5 --- egs/librispeech/ASR/zipformer/zipformer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 933c97cf53..1c33a3061b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -526,6 +526,7 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) + self.scale_limiter = ScaleLimiter(max_scale=0.5) self.norm = BiasNorm(embed_dim) @@ -575,6 +576,8 @@ def forward( src = self.bypass(src_orig, src) + src = self.scale_limiter(src) + return self.norm(src) def streaming_forward( From 2099403f880c560d62d7bda433f1700fb2568177 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 13:53:47 +0800 Subject: [PATCH 0121/1191] add weight_max_rms=1.0, weight_min_rms=0.01, remove DeltaDropout --- egs/librispeech/ASR/zipformer/optim.py | 43 +++++++++++++--------- egs/librispeech/ASR/zipformer/scaling.py | 16 -------- egs/librispeech/ASR/zipformer/zipformer.py | 7 +--- 3 files changed, 27 insertions(+), 39 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 949626b8e6..2cd3fa44ee 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -187,18 +187,19 @@ def scaling_step(group, p, state, grad): (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() ) - param_min_rms = group["param_min_rms"] + # would be p.ndim > 1 not p.ndim > 2 but one dim is for batch of tensors. + min_rms = group["weight_min_rms"] if p.ndim > 2 else group["bias_min_rms"] # scale the step size by param_rms. This is the most important "scaling" part of # ScaledAdam - delta *= param_rms.clamp(min=param_min_rms) + delta *= param_rms.clamp(min=min_rms) if step % size_update_period == size_update_period - 1 and step > 0: # This block updates the size of parameter by adding a step ("delta") value in # the direction of either shrinking or growing it. beta2 = group["betas"][1] size_lr = group["lr"] * group["scalar_lr_scale"] - param_max_rms = group["param_max_rms"] + max_rms = group["weight_max_rms"] if p.ndim > 2 else group["bias_max_rms"] eps = group["eps"] batch_size = p.shape[0] # correct beta2 for the size update period: we will have @@ -219,7 +220,7 @@ def scaling_step(group, p, state, grad): -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom ) - is_too_small = param_rms < param_min_rms + is_too_small = param_rms < min_rms # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -228,11 +229,11 @@ def scaling_step(group, p, state, grad): # either direction. scale_step.clamp_(min=-0.1, max=0.1) - # and ensure the parameter rms after update never exceeds param_max_rms. + # and ensure the parameter rms after update never exceeds max_rms. # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the + # max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + scale_step = torch.minimum(scale_step, (max_rms - param_rms) / param_rms) delta.add_(p * scale_step) @@ -287,12 +288,16 @@ class ScaledAdam(BatchedOptimizer): as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the scaling factor on the learning rate of p_scale. eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) + weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of + learning the scale on the parameters. Weight tensors are defined + as anything with more than one element and ndim > 1. + weight_max_rms: Maximum root-mean-square value of weight tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each weight + parameter tensor to be <= this value). + bias_min_rms: Minimum root-mean-square value of bias tensors, defined as anything with + more than one element and exactly one tensor dimension i.e. ndim == 1. + bias_max_rms: Maximum root-mean-square value of bias tensors, defined as anything with + more than one element and exactly one tensor dimension i.e. ndim == 1. scalar_max: Maximum absolute value for scalar parameters (applicable if your model has any parameters with numel() == 1). size_update_period: The periodicity, in steps, with which we update the size (scale) @@ -309,8 +314,10 @@ def __init__( betas=(0.9, 0.98), scalar_lr_scale=0.1, eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, + weight_min_rms=0.01, + weight_max_rms=1.0, + bias_min_rms=1.0e-05, + bias_max_rms=3.0, scalar_max=10.0, size_update_period=4, clipping_update_period=100, @@ -322,8 +329,10 @@ def __init__( betas=betas, scalar_lr_scale=scalar_lr_scale, eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, + weight_min_rms=weight_min_rms, + weight_max_rms=weight_max_rms, + bias_min_rms=bias_min_rms, + bias_max_rms=bias_max_rms, scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ce269ddf36..6e2ff5481a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1477,22 +1477,6 @@ def forward(self, x: Tensor) -> Tensor: return ans -# DeltaDropout does dropout but you supply a delta and it shrinks the embedding -# element toward zero by at most delta. -class DeltaDropout(nn.Module): - def __init__(self, p: FloatLike, delta: float): - super().__init__() - self.p = p - self.delta = delta - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - - rand = (torch.rand_like(x) < p) * x.abs().clamp_(max=self.delta) * -x.sgn() - return x + rand - class SwooshLFunction(torch.autograd.Function): """ diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 1c33a3061b..7fa1647c2b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -35,7 +35,6 @@ BiasNorm, ChunkCausalDepthwiseConv1d, Dropout2, - DeltaDropout, FloatLike, ScheduledFloat, Whiten, @@ -518,10 +517,6 @@ def __init__( self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - # This is supposed to encourage the scale of the embeddings to get larger if it is too small. - self.dropout_ff2 = DeltaDropout(0.1, delta=0.01) - self.conv_module = ConvolutionModule( embed_dim, cnn_module_kernel, causal=causal ) @@ -572,7 +567,7 @@ def forward( src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask ) - src = src + self.dropout_ff2(self.feed_forward2(src)) + src = src + self.feed_forward2(src) src = self.bypass(src_orig, src) From a78cab0607a4555c744d23c186d61c7b7e6bfacb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 15:11:56 +0800 Subject: [PATCH 0122/1191] Reduce weight_min_rms from 0.01 to 0.002 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 2cd3fa44ee..ba914f9bf1 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -314,7 +314,7 @@ def __init__( betas=(0.9, 0.98), scalar_lr_scale=0.1, eps=1.0e-08, - weight_min_rms=0.01, + weight_min_rms=0.002, weight_max_rms=1.0, bias_min_rms=1.0e-05, bias_max_rms=3.0, From 80c855828130bc29ecb8d14604531e38db866db9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 15:14:28 +0800 Subject: [PATCH 0123/1191] Add a schedule for zipformer layer scale_limiter, start at 2.0 and reduce to 0.5 after 10k iters --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7fa1647c2b..e0049607fe 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -521,7 +521,7 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) - self.scale_limiter = ScaleLimiter(max_scale=0.5) + self.scale_limiter = ScaleLimiter(max_scale=ScheduledFloat((0.0, 2.0), (10000.0, 0.5), default=2.0)) self.norm = BiasNorm(embed_dim) From 374e78a4c69c767140806b17a1b16ea07acba2ce Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 17:13:41 +0800 Subject: [PATCH 0124/1191] Remove lr_scale=0.5 in OrthogonalLinear --- egs/librispeech/ASR/zipformer/scaling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6e2ff5481a..113315ca9a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -567,7 +567,6 @@ def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): self.penalty_scale = copy.deepcopy(penalty_scale) self.min_product_scale = 0.01 self.name = None # will be set from training loop. for printing penalty. - self.lr_scale = 0.5 with torch.no_grad(): # this is not orthogonal but should quickly become so. From 78148f7da4c00ee7c9d18377ad6f9ee8a1687656 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 20:57:51 +0800 Subject: [PATCH 0125/1191] Scale up initialization of encoder_embed.out and feedforward in_proj by factor of 2. --- egs/librispeech/ASR/zipformer/subsampling.py | 5 ++++- egs/librispeech/ASR/zipformer/zipformer.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index c008720ac8..e7f7f7cbea 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -23,6 +23,7 @@ from scaling import ( Balancer, ScaleLimiter, + ScaledLinear, BiasNorm, Dropout3, FloatLike, @@ -249,7 +250,9 @@ def __init__( self.out_width = (((in_channels - 1) // 2) - 1) // 2 self.layer3_channels = layer3_channels - self.out = nn.Linear(self.out_width * layer3_channels, out_channels) + # scale it up a bit, else the output is quite small. + self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, + initial_scale=2.0) self.out_limiter = ScaleLimiter(max_scale=0.5) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e0049607fe..df9255679e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1626,7 +1626,9 @@ class FeedforwardModule(nn.Module): def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) + # try to get in the useful range of the activation function, i.e. not too small. + self.in_proj = ScaledLinear(embed_dim, feedforward_dim, + initial_scale=2.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( From 9f7ce9d64af7ec44f7cecb1bc6c318f7662a3cd0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Feb 2025 23:34:47 +0800 Subject: [PATCH 0126/1191] set size_update_period=1 --- egs/librispeech/ASR/zipformer/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f4344d8ccc..f008bed0c0 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1341,6 +1341,9 @@ def run(rank, world_size, args): get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, # should have no effect clipping_scale=2.0, + size_update_period=1, # for some reason, setting this to 1 (default is + # 4) seems to stop the embeddings from getting too + # small. ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, From af42dccb76139b6f28ab48c3b449768aac50df66 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 11 Feb 2025 12:03:38 +0800 Subject: [PATCH 0127/1191] Increase weight_min_rms from 0.002 to 0.005. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ba914f9bf1..d76f6f6899 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -314,7 +314,7 @@ def __init__( betas=(0.9, 0.98), scalar_lr_scale=0.1, eps=1.0e-08, - weight_min_rms=0.002, + weight_min_rms=0.005, weight_max_rms=1.0, bias_min_rms=1.0e-05, bias_max_rms=3.0, From b3b48545b238be5630b0d831dbd17adff0912d89 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 11 Feb 2025 22:24:40 +0800 Subject: [PATCH 0128/1191] Add the + scale**2 term which is a 2nd order taylor expansion term intended to avoid a certain bias in the update, towards small scales --- egs/librispeech/ASR/zipformer/optim.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index d76f6f6899..d5cd61a6a4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -235,7 +235,12 @@ def scaling_step(group, p, state, grad): # topology or settings. scale_step = torch.minimum(scale_step, (max_rms - param_rms) / param_rms) - delta.add_(p * scale_step) + # the "+ 0.5 * scale_step ** 2" can be thought of as taking the second + # term in the Taylor expansion of exp(s) - 1, which is s + s^2 / 2!. + # this is so that in effect we are learning the scale in log space, + # so to represent it in p we have to exponentiate it. it's to avoid + # a downward bias in the scale that might otherwise happen. + delta.add_(p * (scale_step + 0.5 * scale_step ** 2)) return delta From 711dbc91597a41fc309604cdafd08fe7a3261768 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Feb 2025 10:22:39 +0800 Subject: [PATCH 0129/1191] try to stop params getting clamped at too-small values. --- egs/librispeech/ASR/zipformer/optim.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index d5cd61a6a4..804b725ef4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -220,10 +220,14 @@ def scaling_step(group, p, state, grad): -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom ) - is_too_small = param_rms < min_rms + not_too_small = param_rms > min_rms + + # when the param gets too small, don't shrink it any further. + # that means we set it to zero if it was negative. + # -not_too_small.to(p.dtype) is 0 if it is too small, and -1 if it + # is not too small which will anyway be below the step. + scale_step = torch.maximum(scale_step, -not_too_small.to(p.dtype)) - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) # The following may help prevent instability: don't allow the scale step to be too large in # either direction. From 67ac5c030ddb47a99f019a7bbbca505c39bc6408 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Feb 2025 16:36:36 +0800 Subject: [PATCH 0130/1191] scale_step_factor = ((param_rms / min_rms) - 1.).clamp_(min=0.0, max=1.0), do not update at all as we approach param_min_rms. --- egs/librispeech/ASR/zipformer/optim.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 804b725ef4..2858c8d973 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -220,17 +220,13 @@ def scaling_step(group, p, state, grad): -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom ) - not_too_small = param_rms > min_rms - - # when the param gets too small, don't shrink it any further. - # that means we set it to zero if it was negative. - # -not_too_small.to(p.dtype) is 0 if it is too small, and -1 if it - # is not too small which will anyway be below the step. - scale_step = torch.maximum(scale_step, -not_too_small.to(p.dtype)) - + # turn off the scale-step once param_rms is below min_rms, scale becomes + # 1.0 once we are twice param_min_rms. + scale_step_factor = ((param_rms / min_rms) - 1.).clamp_(min=0.0, max=1.0) # The following may help prevent instability: don't allow the scale step to be too large in # either direction. + # TODO: remove this. scale_step.clamp_(min=-0.1, max=0.1) # and ensure the parameter rms after update never exceeds max_rms. @@ -244,7 +240,7 @@ def scaling_step(group, p, state, grad): # this is so that in effect we are learning the scale in log space, # so to represent it in p we have to exponentiate it. it's to avoid # a downward bias in the scale that might otherwise happen. - delta.add_(p * (scale_step + 0.5 * scale_step ** 2)) + delta.add_(p * (scale_step_factor * (scale_step + 0.5 * scale_step ** 2))) return delta From fde6b01d29fe87943b3ce19c1a098459f7a799ed Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Feb 2025 18:56:18 +0800 Subject: [PATCH 0131/1191] Fix interaction with exponential taylor expansion with scale_step_factor --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 2858c8d973..83baa38e41 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -233,14 +233,14 @@ def scaling_step(group, p, state, grad): # We have to look at the trained model for parameters at or around the # max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = torch.minimum(scale_step, (max_rms - param_rms) / param_rms) + scale_step = scale_step_factor * torch.minimum(scale_step, (max_rms - param_rms) / param_rms) # the "+ 0.5 * scale_step ** 2" can be thought of as taking the second # term in the Taylor expansion of exp(s) - 1, which is s + s^2 / 2!. # this is so that in effect we are learning the scale in log space, # so to represent it in p we have to exponentiate it. it's to avoid # a downward bias in the scale that might otherwise happen. - delta.add_(p * (scale_step_factor * (scale_step + 0.5 * scale_step ** 2))) + delta.add_(p * (scale_step + 0.5 * scale_step ** 2)) return delta From c34aba1529c13122f45bbe7e358a4637107d2022 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 13 Feb 2025 19:53:02 +0800 Subject: [PATCH 0132/1191] Increase initial_scale of in_proj of FeedforwardModule from 2 to 4, and dropout from 0.0 to ScheduledFloat((0.0, 0.4), (3000.0, 0.0)) --- egs/librispeech/ASR/zipformer/train.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f008bed0c0..82d4bb6a2c 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -653,7 +653,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=_to_int_tuple(params.num_heads), feedforward_dim=_to_int_tuple(params.feedforward_dim), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=0.0, + dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), warmup_batches=4000.0, causal=params.causal, chunk_size=_to_int_tuple(params.chunk_size), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index df9255679e..60795ce384 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1628,7 +1628,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=2.0) + initial_scale=4.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( From c463432841ef8f8f4d8c790f47998a34677759f6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Feb 2025 15:01:43 +0800 Subject: [PATCH 0133/1191] Increase weight_min_rms from 0.005 to 0.01 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 83baa38e41..9febfb1ea2 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -319,7 +319,7 @@ def __init__( betas=(0.9, 0.98), scalar_lr_scale=0.1, eps=1.0e-08, - weight_min_rms=0.005, + weight_min_rms=0.01, weight_max_rms=1.0, bias_min_rms=1.0e-05, bias_max_rms=3.0, From 59d89a9a0cb5d12f46e07601b2684101d1e96f50 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Feb 2025 17:17:02 +0800 Subject: [PATCH 0134/1191] Include in scaling_step a correction term that stops the parameter size from increasing due to gradient noise. --- egs/librispeech/ASR/zipformer/optim.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9febfb1ea2..7ff5c0b0fc 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -233,7 +233,16 @@ def scaling_step(group, p, state, grad): # We have to look at the trained model for parameters at or around the # max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = scale_step_factor * torch.minimum(scale_step, (max_rms - param_rms) / param_rms) + scale_step = torch.minimum(scale_step, (max_rms - param_rms) / param_rms) + + + # (1 + lr**2) ** 0.5 ~ 1 + (0.5 lr**2) would be the factor by which the parameter rms + # increases on each step, assuming the gradient is orthogonal to the current + # parameter value. we cancel this out by subtracting (0.5 * lr**2); we + # need to do this times size_update_period. + scale_step = scale_step - (0.5 * (group["lr"] ** 2) * size_update_period) + + scale_step = scale_step_factor * scale_step # the "+ 0.5 * scale_step ** 2" can be thought of as taking the second # term in the Taylor expansion of exp(s) - 1, which is s + s^2 / 2!. From 447aa62bc8c03980df7c70e62932ccb62da4b5b6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Feb 2025 22:13:27 +0800 Subject: [PATCH 0135/1191] Add debug_interval option, tensorboard support --- egs/librispeech/ASR/zipformer/optim.py | 45 ++++++++++++++++++++++++-- egs/librispeech/ASR/zipformer/train.py | 11 ++++++- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9febfb1ea2..6f50034c66 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -261,6 +261,42 @@ def momentum_step(group, p, state, grad): return stored_delta +def debug_step(group, p, state, grad, param_names, summary_writer): + delta = momentum_step(group, p, state, grad) + debug_interval = group["debug_interval"] + step = state["step"] % debug_interval + + if step % debug_interval != 0 or summary_writer is None: + return delta + + debug_info = torch.zeros(p.shape[0], 6, device=p.device, dtype=torch.float) + + is_scalar = (p.numel() == p.shape[0]) + dims = list(range(1, p.ndim)) # e.g. dims to average. + + def maybe_rms(x): + if is_scalar: + # the .mean() is just to get rid of those dims. + return x.mean(dim=dims) + else: + return (x ** 2).mean(dim=dims).sqrt() + + debug_info[:, 0] = maybe_rms(p) + debug_info[:, 1] = maybe_rms(grad) + debug_info[:, 2] = maybe_rms(delta) + debug_info[:, 3] = (p * grad).sum(dim=dims) + debug_info[:, 4] = (p * delta).sum(dim=dims) + debug_info[:, 5] = (grad * delta).sum(dim=dims) + debug_info = debug_info.to('cpu') + + assert len(param_names) == p.shape[0] + for name, info in param_names, debug_info.unbind(dim=0): + for i, legend in enumerate(['param_rms', 'grad_rms', 'delta_rms', 'param_grad', 'param_delta', 'grad_delta']): + summary_writer.add_scalar(f"debug/{legend}/{name}", step, info[i].item()) + + return delta + + class ScaledAdam(BatchedOptimizer): """ @@ -309,6 +345,7 @@ class ScaledAdam(BatchedOptimizer): of the parameter tensor. This is provided to save a little time in the update. clipping_update_period: if clipping_scale is specified, this is the period + debug_interval: if >0, write some statistics to tensorboard every this-many steps. """ def __init__( @@ -326,6 +363,7 @@ def __init__( scalar_max=10.0, size_update_period=4, clipping_update_period=100, + debug_interval=0, ): defaults = dict( @@ -341,6 +379,7 @@ def __init__( scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, + debug_interval=debug_interval, ) # If params only contains parameters or group of parameters, @@ -463,7 +502,7 @@ def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) @torch.no_grad() - def step(self, closure=None): + def step(self, closure=None, summary_writer=None): """Performs a single optimization step. Arguments: @@ -492,7 +531,7 @@ def step(self, closure=None): else: clipping_scale = self._get_clipping_scale(group, batches) - for p, state, _ in batches: + for p, state, names in batches: # Perform optimization step. # grad is not going to be None, we handled that when creating the batches. grad = p.grad @@ -509,7 +548,7 @@ def step(self, closure=None): cur_step = 0 grad = (p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)) - p += momentum_step(group, p.detach(), state, grad) + p += debug_step(group, p.detach(), state, grad, names, summary_writer) if p.numel() == p.shape[0]: # scalar parameter scalar_max = group["scalar_max"] diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 82d4bb6a2c..829a9fa54f 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -356,6 +356,15 @@ def get_parser(): """, ) + parser.add_argument( + "--debug-interval", + type=int, + default=0, + help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for + finding parts of the network that are diverging or not well trained. + """ + ) + parser.add_argument( "--exp-dir", type=str, @@ -1122,7 +1131,7 @@ def save_bad_model(suffix: str = ""): scaler.scale(loss).backward() scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) + scaler.step(optimizer, summary_writer=tb_writer) scaler.update() optimizer.zero_grad() except Exception as e: From 8e0725a8736ef128cd2b988ced666540330bb9ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Feb 2025 22:28:37 +0800 Subject: [PATCH 0136/1191] Bug fixes --- egs/librispeech/ASR/zipformer/optim.py | 16 +++++++++++++++- egs/librispeech/ASR/zipformer/train.py | 1 + 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3cc8573f0e..5f46635c38 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -275,7 +275,7 @@ def debug_step(group, p, state, grad, param_names, summary_writer): debug_interval = group["debug_interval"] step = state["step"] % debug_interval - if step % debug_interval != 0 or summary_writer is None: + if debug_interval == 0 or step % debug_interval != 0 or summary_writer is None: return delta debug_info = torch.zeros(p.shape[0], 6, device=p.device, dtype=torch.float) @@ -307,6 +307,15 @@ def maybe_rms(x): +def _load_state_dict_pre_hook(optim: ScaledAdam, state_dict: dict): + for optim_group, load_group in zip(optim.param_groups, state_dict['param_groups']): + for key in ['debug_interval']: + try: + optim_group[key] = load_group[key] + logging.info(f"Copied key {key}") + except KeyError: + logging.info(f"Could not copy key {key} from optim state-dict.") + class ScaledAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update @@ -400,6 +409,11 @@ def __init__( assert len(self.param_groups) == len(parameters_names) self.parameters_names = parameters_names + + self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook) + + + def _get_names_of_parameters( self, params_or_named_params ) -> Tuple[List[Dict], List[List[str]]]: diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 829a9fa54f..965308c975 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1353,6 +1353,7 @@ def run(rank, world_size, args): size_update_period=1, # for some reason, setting this to 1 (default is # 4) seems to stop the embeddings from getting too # small. + debug_interval=params.debug_interval, ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, From 0e80775d4a176f7ced2f5c9e7e42dd52d0e3b0cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 11:31:38 +0800 Subject: [PATCH 0137/1191] Bug fixes --- egs/librispeech/ASR/zipformer/optim.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 5f46635c38..8dd9f9aac9 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -273,7 +273,7 @@ def momentum_step(group, p, state, grad): def debug_step(group, p, state, grad, param_names, summary_writer): delta = momentum_step(group, p, state, grad) debug_interval = group["debug_interval"] - step = state["step"] % debug_interval + step = state["step"] if debug_interval == 0 or step % debug_interval != 0 or summary_writer is None: return delta @@ -299,15 +299,15 @@ def maybe_rms(x): debug_info = debug_info.to('cpu') assert len(param_names) == p.shape[0] - for name, info in param_names, debug_info.unbind(dim=0): + for name, info in zip(param_names, debug_info.unbind(dim=0)): for i, legend in enumerate(['param_rms', 'grad_rms', 'delta_rms', 'param_grad', 'param_delta', 'grad_delta']): - summary_writer.add_scalar(f"debug/{legend}/{name}", step, info[i].item()) + summary_writer.add_scalar(f"debug/{legend}/{name}", info[i].item(), step) return delta -def _load_state_dict_pre_hook(optim: ScaledAdam, state_dict: dict): +def _load_state_dict_pre_hook(optim: Optimizer, state_dict: dict): for optim_group, load_group in zip(optim.param_groups, state_dict['param_groups']): for key in ['debug_interval']: try: From 6dbc0643e9291f267a746b17297d94b3214afe31 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 11:39:17 +0800 Subject: [PATCH 0138/1191] Remove size_update_period=1, revert to default of 4. --- egs/librispeech/ASR/zipformer/train.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 82d4bb6a2c..189f6aa7ba 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1341,9 +1341,6 @@ def run(rank, world_size, args): get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, # should have no effect clipping_scale=2.0, - size_update_period=1, # for some reason, setting this to 1 (default is - # 4) seems to stop the embeddings from getting too - # small. ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, From f8c3ea876f2adf9069189697f79082caccaa17bf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 14:03:53 +0800 Subject: [PATCH 0139/1191] Do not save debug information until failure. --- egs/librispeech/ASR/zipformer/optim.py | 83 ++++++++++++++++++++++---- egs/librispeech/ASR/zipformer/train.py | 5 +- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8dd9f9aac9..2556e04a18 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -270,19 +270,36 @@ def momentum_step(group, p, state, grad): return stored_delta -def debug_step(group, p, state, grad, param_names, summary_writer): - delta = momentum_step(group, p, state, grad) + +def debug_step(group, p, state, grad): debug_interval = group["debug_interval"] + debug_buffer_size = 256 step = state["step"] if debug_interval == 0 or step % debug_interval != 0 or summary_writer is None: + delta = momentum_step(group, p, state, grad) return delta - debug_info = torch.zeros(p.shape[0], 6, device=p.device, dtype=torch.float) - is_scalar = (p.numel() == p.shape[0]) dims = list(range(1, p.ndim)) # e.g. dims to average. + try: + old_delta = state["delta"] + grad_old_delta = (grad * old_delta).sum(dim=dims) + except KeyError: + grad_old_delta = 0.0 + + delta = momentum_step(group, p, state, grad) + + try: + debug_info = state["debug_info"] + except KeyError: + debug_info = torch.zeros(debug_buffer_size, p.shape[0], 6, + device=p.device, dtype=torch.float) + state["debug_info"] = debug_info + + is_scalar = (p.numel() == p.shape[0]) + def maybe_rms(x): if is_scalar: # the .mean() is just to get rid of those dims. @@ -290,20 +307,53 @@ def maybe_rms(x): else: return (x ** 2).mean(dim=dims).sqrt() + + debug_info = debug_info[(step // debug_interval) % debug_buffer_size] + debug_info[:, 0] = maybe_rms(p) debug_info[:, 1] = maybe_rms(grad) debug_info[:, 2] = maybe_rms(delta) debug_info[:, 3] = (p * grad).sum(dim=dims) debug_info[:, 4] = (p * delta).sum(dim=dims) - debug_info[:, 5] = (grad * delta).sum(dim=dims) + debug_info[:, 5] = grad_old_delta + + return delta + + +def _write_debug_info(group, state, param_names, summary_writer): + """ + Writes to a Tensorboard, model-debugging information that was accumulated in debug_step. + """ + cur_step = state["step"] + debug_interval = group["debug_interval"] + + try: + debug_info = state["debug_info"] + except KeyError: + return + + (debug_buffer_size, num_params, _six) = debug_info.shape + + # cur_index would be where the next debug_info would go in the buffer + cur_index = (cur_step // debug_interval) % debug_buffer_size + # roll the data in the buffer so that cur_index goes to position zero. + debug_info = torch.roll(debug_info, -cur_index, 0, 0) + + debug_info = debug_info.to('cpu') - assert len(param_names) == p.shape[0] - for name, info in zip(param_names, debug_info.unbind(dim=0)): - for i, legend in enumerate(['param_rms', 'grad_rms', 'delta_rms', 'param_grad', 'param_delta', 'grad_delta']): - summary_writer.add_scalar(f"debug/{legend}/{name}", info[i].item(), step) + assert len(param_names) == num_params + + for step in range(debug_buffer_size): + # this formula for real_step is rather approximate, it doesn't properly + # account for end effetcs, or missed steps in amp mode due to infinities. + real_step = debug_interval * (step - debug_buffer_size) + cur_step + + for name, info in zip(param_names, debug_info[step].unbind(dim=0)): + for i, legend in enumerate(['param_rms', 'grad_rms', 'delta_rms', 'param_grad', 'param_delta', 'grad_delta']): + summary_writer.add_scalar(f"debug/{legend}/{name}", info[i].item(), real_step) + - return delta @@ -525,7 +575,7 @@ def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) @torch.no_grad() - def step(self, closure=None, summary_writer=None): + def step(self, closure=None): """Performs a single optimization step. Arguments: @@ -554,7 +604,7 @@ def step(self, closure=None, summary_writer=None): else: clipping_scale = self._get_clipping_scale(group, batches) - for p, state, names in batches: + for p, state, _names in batches: # Perform optimization step. # grad is not going to be None, we handled that when creating the batches. grad = p.grad @@ -571,7 +621,7 @@ def step(self, closure=None, summary_writer=None): cur_step = 0 grad = (p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)) - p += debug_step(group, p.detach(), state, grad, names, summary_writer) + p += debug_step(group, p.detach(), state, grad) if p.numel() == p.shape[0]: # scalar parameter scalar_max = group["scalar_max"] @@ -582,6 +632,13 @@ def step(self, closure=None, summary_writer=None): return loss + @torch.no_grad() + def write_debug_info(self, summary_writer): + for group, group_params_names in zip(self.param_groups, self.parameters_names): + with self.batched_params(group["params"], group_params_names) as batches: + for _p, state, names in batches: + _write_debug_info(group, state, names, summary_writer) + def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] ) -> float: diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 965308c975..ac19bb54c2 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1131,11 +1131,14 @@ def save_bad_model(suffix: str = ""): scaler.scale(loss).backward() scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer, summary_writer=tb_writer) + scaler.step(optimizer) scaler.update() optimizer.zero_grad() except Exception as e: logging.info(f"Caught exception: {e}.") + if params.debug_interval > 0: + logging.info("Writing debug info to tensorboard.") + scaler.write_debug_info(summary_writer=tb_writer) save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise From 9cff3f1e20b8992dc18992458deb78a3954bb8cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 14:46:47 +0800 Subject: [PATCH 0140/1191] Bug fixes --- egs/librispeech/ASR/zipformer/optim.py | 4 +--- egs/librispeech/ASR/zipformer/train.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 2556e04a18..3ddcdb75f1 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -276,11 +276,10 @@ def debug_step(group, p, state, grad): debug_buffer_size = 256 step = state["step"] - if debug_interval == 0 or step % debug_interval != 0 or summary_writer is None: + if debug_interval == 0 or step % debug_interval != 0: delta = momentum_step(group, p, state, grad) return delta - dims = list(range(1, p.ndim)) # e.g. dims to average. try: @@ -339,7 +338,6 @@ def _write_debug_info(group, state, param_names, summary_writer): # roll the data in the buffer so that cur_index goes to position zero. debug_info = torch.roll(debug_info, -cur_index, 0, 0) - debug_info = debug_info.to('cpu') assert len(param_names) == num_params diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index ac19bb54c2..5578bdd5a6 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1138,7 +1138,7 @@ def save_bad_model(suffix: str = ""): logging.info(f"Caught exception: {e}.") if params.debug_interval > 0: logging.info("Writing debug info to tensorboard.") - scaler.write_debug_info(summary_writer=tb_writer) + optimizer.write_debug_info(summary_writer=tb_writer) save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise From 9516800b08c80fecd52a39ae2fb294b1e38173d7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 15:48:55 +0800 Subject: [PATCH 0141/1191] Reduce correction_factor from 0.5 to 0.4, to correct for parameter growth. --- egs/librispeech/ASR/zipformer/optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7ff5c0b0fc..315e6b832f 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -240,7 +240,9 @@ def scaling_step(group, p, state, grad): # increases on each step, assuming the gradient is orthogonal to the current # parameter value. we cancel this out by subtracting (0.5 * lr**2); we # need to do this times size_update_period. - scale_step = scale_step - (0.5 * (group["lr"] ** 2) * size_update_period) + + CORRECTION_FACTOR = 0.4 # mathematically this should be 0.5 + scale_step = scale_step - (CORRECTION_FACTOR * (group["lr"] ** 2) * size_update_period) scale_step = scale_step_factor * scale_step From 55a62c04c1f50756046a3a5b514f75aad27dfa32 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 16:42:47 +0800 Subject: [PATCH 0142/1191] Reduce lower limit of eps from 1.0 to 0.5. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 113315ca9a..3193e11195 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -474,7 +474,7 @@ def forward(self, x: Tensor) -> Tensor: return (x * scales) eps = limit_param_value( - self.eps, min=1.0, max=2.0, training=self.training) + self.eps, min=0.5, max=2.0, training=self.training) power = limit_param_value( self.power, min=0.25, max=2.0, training=self.training) From f1ed4789c27b6760ef91d6e51686a131827e8b5c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 16:46:16 +0800 Subject: [PATCH 0143/1191] Increase weight_max_rms from 1.0 to 1.5. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7ff5c0b0fc..a9e2efe367 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -329,7 +329,7 @@ def __init__( scalar_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.01, - weight_max_rms=1.0, + weight_max_rms=1.5, bias_min_rms=1.0e-05, bias_max_rms=3.0, scalar_max=10.0, From 06cc81a63660b7a893ad6142f7fb0736aec4b384 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 16:52:05 +0800 Subject: [PATCH 0144/1191] Reduce weight_min_rms from 0.01 to 0.005. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7ff5c0b0fc..062f67313a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -328,7 +328,7 @@ def __init__( betas=(0.9, 0.98), scalar_lr_scale=0.1, eps=1.0e-08, - weight_min_rms=0.01, + weight_min_rms=0.005, weight_max_rms=1.0, bias_min_rms=1.0e-05, bias_max_rms=3.0, From af384428604b9a1c29e4504243b18b7f33745659 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 21:00:09 +0800 Subject: [PATCH 0145/1191] Bug fix RE when to write debug info --- egs/librispeech/ASR/zipformer/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 5578bdd5a6..eb615fe1c7 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1092,6 +1092,9 @@ def train_one_epoch( saved_bad_model = False def save_bad_model(suffix: str = ""): + if params.debug_interval > 0: + logging.info("Writing debug info to tensorboard.") + optimizer.write_debug_info(summary_writer=tb_writer) save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, @@ -1136,9 +1139,6 @@ def save_bad_model(suffix: str = ""): optimizer.zero_grad() except Exception as e: logging.info(f"Caught exception: {e}.") - if params.debug_interval > 0: - logging.info("Writing debug info to tensorboard.") - optimizer.write_debug_info(summary_writer=tb_writer) save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise From 55dca0b2d50cf132a751f15e33f017db51684eb6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Feb 2025 22:24:37 +0800 Subject: [PATCH 0146/1191] Bug fix --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3ddcdb75f1..ef70f5d722 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -336,7 +336,7 @@ def _write_debug_info(group, state, param_names, summary_writer): # cur_index would be where the next debug_info would go in the buffer cur_index = (cur_step // debug_interval) % debug_buffer_size # roll the data in the buffer so that cur_index goes to position zero. - debug_info = torch.roll(debug_info, -cur_index, 0, 0) + debug_info = torch.roll(debug_info, -cur_index, 0) debug_info = debug_info.to('cpu') From 271c2a6a32a565c1bccc0ac498f0e5cec2731233 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Feb 2025 13:17:14 +0800 Subject: [PATCH 0147/1191] Reduce CORRECTION_FACTOR from 0.4 to 0.25. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 315e6b832f..83ab30967f 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -241,7 +241,7 @@ def scaling_step(group, p, state, grad): # parameter value. we cancel this out by subtracting (0.5 * lr**2); we # need to do this times size_update_period. - CORRECTION_FACTOR = 0.4 # mathematically this should be 0.5 + CORRECTION_FACTOR = 0.25 # mathematically this should be 0.5 scale_step = scale_step - (CORRECTION_FACTOR * (group["lr"] ** 2) * size_update_period) scale_step = scale_step_factor * scale_step From 3f470cfd382e0bac83dfe5662bcc5de2bfab71fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Feb 2025 17:24:24 +0800 Subject: [PATCH 0148/1191] CORRECTION_FACTOR = 0.25 if is_weight else 0.5, more shrinkage for biases. --- egs/librispeech/ASR/zipformer/optim.py | 7 +++++-- egs/librispeech/ASR/zipformer/train.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c673ba2f25..bdaaa6b3f9 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -188,7 +188,8 @@ def scaling_step(group, p, state, grad): ) # would be p.ndim > 1 not p.ndim > 2 but one dim is for batch of tensors. - min_rms = group["weight_min_rms"] if p.ndim > 2 else group["bias_min_rms"] + is_weight = (p.ndim > 2) + min_rms = group["weight_min_rms"] if is_weight else group["bias_min_rms"] # scale the step size by param_rms. This is the most important "scaling" part of # ScaledAdam @@ -241,7 +242,9 @@ def scaling_step(group, p, state, grad): # parameter value. we cancel this out by subtracting (0.5 * lr**2); we # need to do this times size_update_period. - CORRECTION_FACTOR = 0.25 # mathematically this should be 0.5 + CORRECTION_FACTOR = 0.25 if is_weight else 0.5 + # mathematically this should be 0.5. 0.25 gives less-aggressive shrinkage. give the more-aggressive shrinkage + # of 0.5 for biases, as the biases getting relatively smaller will tend to prevent failure of the grad to propagate. scale_step = scale_step - (CORRECTION_FACTOR * (group["lr"] ** 2) * size_update_period) scale_step = scale_step_factor * scale_step diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index eb615fe1c7..c73b6942ab 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -359,7 +359,7 @@ def get_parser(): parser.add_argument( "--debug-interval", type=int, - default=0, + default=10, help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for finding parts of the network that are diverging or not well trained. """ From 4112ab58e3b0dfd41477c54dcad4fa2b17605a17 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Feb 2025 19:03:30 +0800 Subject: [PATCH 0149/1191] Init biases ten times smaller in ScaledLinear --- egs/librispeech/ASR/zipformer/scaling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3193e11195..234d2372cc 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -511,7 +511,8 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + ans.bias[:] = 0.0 + torch.nn.init.uniform_(ans.bias, -0.01 * initial_scale, 0.01 * initial_scale) return ans From 90e905bb712b94dcc34612654eb744fb87e0bfe2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Feb 2025 19:15:27 +0800 Subject: [PATCH 0150/1191] Add --dump-debug-interval option --- egs/librispeech/ASR/zipformer/optim.py | 1 + egs/librispeech/ASR/zipformer/train.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index bdaaa6b3f9..71869f82c7 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -637,6 +637,7 @@ def step(self, closure=None): @torch.no_grad() def write_debug_info(self, summary_writer): + logging.info("Writing debug info to tensorboard.") for group, group_params_names in zip(self.param_groups, self.parameters_names): with self.batched_params(group["params"], group_params_names) as batches: for _p, state, names in batches: diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index c73b6942ab..d099d4629f 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -365,6 +365,15 @@ def get_parser(): """ ) + parser.add_argument( + "--dump-debug-interval", + type=int, + default=0, + help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they + are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. + """ + ) + parser.add_argument( "--exp-dir", type=str, @@ -1093,7 +1102,6 @@ def train_one_epoch( def save_bad_model(suffix: str = ""): if params.debug_interval > 0: - logging.info("Writing debug info to tensorboard.") optimizer.write_debug_info(summary_writer=tb_writer) save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", @@ -1247,6 +1255,10 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) + if params.batch_idx_train > 0 and params.batch_idx_train % params.dump_debug_interval > 0: + optimizer.write_debug_info(summary_writer=tb_writer) + + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: From 1b6a90a07cf7680500d3ac9c34a51e7194aa14a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Feb 2025 22:01:10 +0800 Subject: [PATCH 0151/1191] Various bug fixes etc., for debug-writing --- egs/librispeech/ASR/zipformer/optim.py | 40 +++++++++----------------- egs/librispeech/ASR/zipformer/train.py | 6 ++-- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 71869f82c7..35c8a7482b 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -281,29 +281,21 @@ def debug_step(group, p, state, grad): debug_buffer_size = 256 step = state["step"] + delta = momentum_step(group, p, state, grad) + if debug_interval == 0 or step % debug_interval != 0: - delta = momentum_step(group, p, state, grad) return delta - dims = list(range(1, p.ndim)) # e.g. dims to average. - - try: - old_delta = state["delta"] - grad_old_delta = (grad * old_delta).sum(dim=dims) - except KeyError: - grad_old_delta = 0.0 - - delta = momentum_step(group, p, state, grad) - try: debug_info = state["debug_info"] except KeyError: - debug_info = torch.zeros(debug_buffer_size, p.shape[0], 6, + debug_info = torch.zeros(debug_buffer_size, p.shape[0], 2, device=p.device, dtype=torch.float) state["debug_info"] = debug_info is_scalar = (p.numel() == p.shape[0]) + dims = list(range(1, p.ndim)) # e.g. dims to average. def maybe_rms(x): if is_scalar: # the .mean() is just to get rid of those dims. @@ -316,10 +308,6 @@ def maybe_rms(x): debug_info[:, 0] = maybe_rms(p) debug_info[:, 1] = maybe_rms(grad) - debug_info[:, 2] = maybe_rms(delta) - debug_info[:, 3] = (p * grad).sum(dim=dims) - debug_info[:, 4] = (p * delta).sum(dim=dims) - debug_info[:, 5] = grad_old_delta return delta @@ -336,7 +324,7 @@ def _write_debug_info(group, state, param_names, summary_writer): except KeyError: return - (debug_buffer_size, num_params, _six) = debug_info.shape + (debug_buffer_size, num_params, _two) = debug_info.shape # cur_index would be where the next debug_info would go in the buffer cur_index = (cur_step // debug_interval) % debug_buffer_size @@ -347,16 +335,14 @@ def _write_debug_info(group, state, param_names, summary_writer): assert len(param_names) == num_params - for step in range(debug_buffer_size): - # this formula for real_step is rather approximate, it doesn't properly - # account for end effetcs, or missed steps in amp mode due to infinities. - real_step = debug_interval * (step - debug_buffer_size) + cur_step - - for name, info in zip(param_names, debug_info[step].unbind(dim=0)): - for i, legend in enumerate(['param_rms', 'grad_rms', 'delta_rms', 'param_grad', 'param_delta', 'grad_delta']): - summary_writer.add_scalar(f"debug/{legend}/{name}", info[i].item(), real_step) - + arange = torch.arange(debug_buffer_size) + steps = debug_interval * (arange - debug_buffer_size) + cur_step + for i, legend in enumerate(['param_rms', 'grad_rms']): + for name, info in zip(param_names, debug_info[..., i].unbind(dim=1)): + debug_str = f"debug/{legend}/{name}" + for step, value in zip(steps.tolist(), info.tolist()): + summary_writer.add_scalar(debug_str, value, step) @@ -637,6 +623,8 @@ def step(self, closure=None): @torch.no_grad() def write_debug_info(self, summary_writer): + if summary_writer is None: + return logging.info("Writing debug info to tensorboard.") for group, group_params_names in zip(self.param_groups, self.parameters_names): with self.batched_params(group["params"], group_params_names) as batches: diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index d099d4629f..6e8a191e26 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1255,10 +1255,9 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) - if params.batch_idx_train > 0 and params.batch_idx_train % params.dump_debug_interval > 0: + if params.batch_idx_train > 0 and params.batch_idx_train % params.dump_debug_interval == 0: optimizer.write_debug_info(summary_writer=tb_writer) - loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -1289,7 +1288,8 @@ def run(rank, world_size, args): logging.info("Training started") if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard", + max_queue=3000) else: tb_writer = None From fd4674c2a5fdf9a89ce891b7732c2d443f797408 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Feb 2025 23:08:38 +0800 Subject: [PATCH 0152/1191] Attempts to speed up debug_stats writing. --- egs/librispeech/ASR/zipformer/optim.py | 12 +++++++++++- egs/librispeech/ASR/zipformer/train.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 35c8a7482b..10ed2c4747 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -320,7 +320,7 @@ def _write_debug_info(group, state, param_names, summary_writer): debug_interval = group["debug_interval"] try: - debug_info = state["debug_info"] + debug_info = state["debug_info_cpu"] except KeyError: return @@ -626,10 +626,20 @@ def write_debug_info(self, summary_writer): if summary_writer is None: return logging.info("Writing debug info to tensorboard.") + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + with self.batched_params(group["params"], group_params_names) as batches: + for _p, state, names in batches: + try: + state["debug_info_cpu"] = state["debug_info"].to(device="cpu", non_blocking=True) + except: + pass + for group, group_params_names in zip(self.param_groups, self.parameters_names): with self.batched_params(group["params"], group_params_names) as batches: for _p, state, names in batches: _write_debug_info(group, state, names, summary_writer) + del state["debug_info_cpu"] def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 6e8a191e26..7a7d4aade0 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1289,7 +1289,7 @@ def run(rank, world_size, args): if args.tensorboard and rank == 0: tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard", - max_queue=3000) + max_queue=100000) else: tb_writer = None From 6677998db38ee883f5e751d3e1ce681ac56f9641 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Feb 2025 23:28:53 +0800 Subject: [PATCH 0153/1191] Fix speed issue of --dump-debug-stats --- egs/librispeech/ASR/zipformer/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 7a7d4aade0..0f4395e06f 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1288,8 +1288,14 @@ def run(rank, world_size, args): logging.info("Training started") if args.tensorboard and rank == 0: + # the reason for the very large max_queue is this: if --dump-debug-interval is set, + # e.g. to 2560, every that-many batches we will dump a very large number of events + # to the writer. These are added to a queue that is drained raather slowly. + # If we make the max_queue large enough to include all the events from calling + # "optimizer.write_debug_info(), we can continue with training and let the + # background thread take care of dumping those events at its own speed. tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard", - max_queue=100000) + max_queue=10000000) else: tb_writer = None From fe0a211edb88c9f329ecb2d4b383839efad9e48f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Feb 2025 16:35:54 +0800 Subject: [PATCH 0154/1191] Change correction_factor from 0.25 to 0.35 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b5559bd285..ce5ba69abd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -242,7 +242,7 @@ def scaling_step(group, p, state, grad): # parameter value. we cancel this out by subtracting (0.5 * lr**2); we # need to do this times size_update_period. - CORRECTION_FACTOR = 0.25 if is_weight else 0.5 + CORRECTION_FACTOR = 0.35 if is_weight else 0.5 # mathematically this should be 0.5. 0.25 gives less-aggressive shrinkage. give the more-aggressive shrinkage # of 0.5 for biases, as the biases getting relatively smaller will tend to prevent failure of the grad to propagate. scale_step = scale_step - (CORRECTION_FACTOR * (group["lr"] ** 2) * size_update_period) From 6ccab03aee0e1dfc95750e84c44c96255d67d17c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Feb 2025 20:38:19 +0800 Subject: [PATCH 0155/1191] Add another self_attn, conv, feedforward module in zipformer layer; reduce initial_scale of FeedforwardModule from 4 to 3. --- egs/librispeech/ASR/zipformer/zipformer.py | 26 +++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 60795ce384..3cbcd65144 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -509,17 +509,17 @@ def __init__( dropout=0.0, ) - self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] - self.feed_forward1 = FeedforwardModule( - embed_dim, (feedforward_dim * 3) // 4, dropout - ) + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.conv_module = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) + + + self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + for _ in range(2) ] self.scale_limiter = ScaleLimiter(max_scale=ScheduledFloat((0.0, 2.0), (10000.0, 0.5), default=2.0)) @@ -563,12 +563,16 @@ def forward( src = src + self.self_attn1(src, attn_weights) - src = src + self.conv_module( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ) + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward2(src) + src = src + self.self_attn2(src, attn_weights) + + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward3(src) + src = self.bypass(src_orig, src) src = self.scale_limiter(src) @@ -1628,7 +1632,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=4.0) + initial_scale=3.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( From 251ac2841b30ad402bf5135556e1750bf5c8ea7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 11:04:46 +0800 Subject: [PATCH 0156/1191] Bug fix, prevent crash when dump-debug-interval == 0 --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 0f4395e06f..e32f018e77 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1255,7 +1255,7 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) - if params.batch_idx_train > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: optimizer.write_debug_info(summary_writer=tb_writer) loss_value = tot_loss["loss"] / tot_loss["frames"] From aa7023184223803d238e4ca2e11c14d9fe0f1773 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 13:13:49 +0800 Subject: [PATCH 0157/1191] reduce scalar_lr_scale from 0.1 to 0.025; increase initial_scale of ff modules from 3.0 to 5.0. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ce5ba69abd..67f0691782 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -411,7 +411,7 @@ def __init__( lr=3e-02, clipping_scale=None, betas=(0.9, 0.98), - scalar_lr_scale=0.1, + scalar_lr_scale=0.025, eps=1.0e-08, weight_min_rms=0.005, weight_max_rms=1.0, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3cbcd65144..4805f085a5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1632,7 +1632,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=3.0) + initial_scale=5.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( From 08d2f3c1514545f3b1da88594b2da977e1e13e0a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 13:16:07 +0800 Subject: [PATCH 0158/1191] Reduce scalar_lr_scale from 0.1 to 0.025; Increase initial_scale of feedforward in_proj from 4.0 to 5.0 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ce5ba69abd..67f0691782 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -411,7 +411,7 @@ def __init__( lr=3e-02, clipping_scale=None, betas=(0.9, 0.98), - scalar_lr_scale=0.1, + scalar_lr_scale=0.025, eps=1.0e-08, weight_min_rms=0.005, weight_max_rms=1.0, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 60795ce384..6beb5bd122 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1628,7 +1628,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=4.0) + initial_scale=5.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( From 695eb5f288cb1b3ac93db4534886d6adc4ee7d93 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 13:29:16 +0800 Subject: [PATCH 0159/1191] Replace BiasNorm with LogNorm --- egs/librispeech/ASR/zipformer/scaling.py | 74 ++++++++++-------------- 1 file changed, 30 insertions(+), 44 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 234d2372cc..ab04cf5c42 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -363,7 +363,14 @@ def backward(ctx, x_grad, *args): return x_grad + x_extra_grad.detach(), None, None, None, None -class BiasNormFunction(torch.autograd.Function): +def _log_norm(x: Tensor, scale: Tensor, channel_dim: int): + x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() + scales = torch.log1p(x_norm) / x_norm + scales = torch.nan_to_num(scales, nan=1.0, posinf=1.0, neginf=1.0) + scales = scale * scales + return (x * scales) + +class LogNormFunction(torch.autograd.Function): # This computes: # scales = (torch.mean(x ** 2 + eps, keepdim=True)) ** -0.5 * log_scale.exp() # return x * scales @@ -374,41 +381,29 @@ class BiasNormFunction(torch.autograd.Function): def forward( ctx, x: Tensor, - eps: Tensor, - power: Tensor, scale: Tensor, channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim + ctx.save_for_backward(x, scale) + return _log_norm(x, scale, channel_dim) - x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = scale * (x_sq ** power + eps) ** (-0.5 / power) - ans = x * scales - ctx.save_for_backward( - x.detach(), - eps.detach(), - power.detach(), - scale.detach(), - ) return ans @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, eps, power, scale = ctx.saved_tensors + x, scale= ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): - x, power, eps, scale = x.to(torch.float32), power.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) - x, power, eps, scale = x.detach(), power.detach(), eps.detach(), scale.detach() + x, scale = x.to(torch.float32), scale.to(torch.float32) + x, scale = x.detach(), scale.detach() x.requires_grad = True - eps.requires_grad = True - power.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - x_sq = torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) - scales = scale * (x_sq ** power + eps) ** (-0.5 / power) - ans = x * scales + ans = _log_norm(x, scale, ctx.channel_dim) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -416,11 +411,15 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(eps.grad), c(power.grad), c(scale.grad), None + return x.grad, c(scale.grad), None + class BiasNorm(torch.nn.Module): """ + Comment not up-to-date. This is LogNorm. Will change docs later if + promising. + This is intended to be a simpler, and hopefully cheaper, replacement for LayerNorm. The observation this is based on, is that Transformer-type networks, especially with pre-norm, sometimes seem to set one of the @@ -446,7 +445,6 @@ class BiasNorm(torch.nn.Module): log_scale_min: FloatLike, minimum allowed value of log_scale log_scale_max: FloatLike, maximum allowed value of log_scale """ - def __init__( self, num_channels: int, @@ -455,9 +453,7 @@ def __init__( super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.scale = nn.Parameter(torch.tensor(2.0)) - self.eps = nn.Parameter(torch.tensor(1.0)) - self.power = nn.Parameter(torch.tensor(1.0)) + self.scale = nn.Parameter(torch.tensor(1.0)) self.name = None @@ -466,31 +462,21 @@ def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels if torch.jit.is_scripting() or torch.jit.is_tracing(): - channel_dim = self.channel_dim - if channel_dim < 0: - channel_dim += x.ndim - x_sq = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - scales = self.scale * (x_sq ** self.power + self.eps) ** (-0.5 / self.power) - return (x * scales) - - eps = limit_param_value( - self.eps, min=0.5, max=2.0, training=self.training) - - power = limit_param_value( - self.power, min=0.25, max=2.0, training=self.training) + return _log_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.5, max=4.0, training=self.training) + self.scale, min=0.5, max=2.0, training=self.training) - if random.random() < 0.002: - x_rms = (x ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, power={power.item()}, eps={eps.item()}, eps**(1/power)={(eps ** (1/power))}, scale={scale.item()}, (eps**(0.5/power))/x_rms={(eps**(0.5/power))/x_rms}") - - return BiasNormFunction.apply( - x, eps, power, scale, self.channel_dim, + ans = LogNormFunction.apply( + x, scale, self.channel_dim, ) + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}") + return ans def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ From 41a784f7929a5b3c30ef869c3ac4800130e34577 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 13:37:41 +0800 Subject: [PATCH 0160/1191] Initialize scale to e-1 --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ab04cf5c42..8022a04cb9 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -394,7 +394,7 @@ def forward( @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, scale= ctx.saved_tensors + x, scale = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): x, scale = x.to(torch.float32), scale.to(torch.float32) @@ -453,7 +453,7 @@ def __init__( super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.scale = nn.Parameter(torch.tensor(1.0)) + self.scale = nn.Parameter(torch.tensor(1.718281828)) self.name = None @@ -465,7 +465,7 @@ def forward(self, x: Tensor) -> Tensor: return _log_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.5, max=2.0, training=self.training) + self.scale, min=0.5, max=2.5, training=self.training) ans = LogNormFunction.apply( x, scale, self.channel_dim, From 499b3b7d2ae147eaf2eb5bec87cf8714366315f6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 15:23:01 +0800 Subject: [PATCH 0161/1191] Increase max_scale values from 0.5 to 4.0 --- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index e7f7f7cbea..d9297c09be 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -254,7 +254,7 @@ def __init__( self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, initial_scale=2.0) - self.out_limiter = ScaleLimiter(max_scale=0.5) + self.out_limiter = ScaleLimiter(max_scale=4.0) # use a larger than normal grad_scale on this whitening module; there is # only one such module, so there is not a concern about adding together diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6beb5bd122..ce9e753719 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -521,7 +521,7 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) - self.scale_limiter = ScaleLimiter(max_scale=ScheduledFloat((0.0, 2.0), (10000.0, 0.5), default=2.0)) + self.scale_limiter = ScaleLimiter(max_scale=4.0) self.norm = BiasNorm(embed_dim) From cbd361bef6bb8fd4939854205bab2a01ef1edc81 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 17:46:26 +0800 Subject: [PATCH 0162/1191] Change initial scales of in,out proj of feedforward to 5,0.1 to 10,0.05; change initial scale of encoder_embed.out from 2 to 4. --- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index e7f7f7cbea..96586fd6b3 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -252,7 +252,7 @@ def __init__( # scale it up a bit, else the output is quite small. self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, - initial_scale=2.0) + initial_scale=4.0) self.out_limiter = ScaleLimiter(max_scale=0.5) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4805f085a5..87b8da9eef 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1632,7 +1632,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=5.0) + initial_scale=10.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( @@ -1642,7 +1642,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): dropout_p=dropout, dropout_shared_dim=0, bias=True, - initial_scale=0.1, + initial_scale=0.05, ) self.out_whiten = Whiten( From dca251884d71b3c104ea3c220c9903ef55e8160d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 19:36:49 +0800 Subject: [PATCH 0163/1191] Change LogNorm to ExpNorm. --- egs/librispeech/ASR/zipformer/scaling.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8022a04cb9..241587d8f0 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -363,14 +363,14 @@ def backward(ctx, x_grad, *args): return x_grad + x_extra_grad.detach(), None, None, None, None -def _log_norm(x: Tensor, scale: Tensor, channel_dim: int): + +def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - scales = torch.log1p(x_norm) / x_norm - scales = torch.nan_to_num(scales, nan=1.0, posinf=1.0, neginf=1.0) + scales = (1. - (-x_norm).exp()) / x_norm # torch.log1p(x_norm) / x_norm scales = scale * scales return (x * scales) -class LogNormFunction(torch.autograd.Function): +class ExpNormFunction(torch.autograd.Function): # This computes: # scales = (torch.mean(x ** 2 + eps, keepdim=True)) ** -0.5 * log_scale.exp() # return x * scales @@ -388,7 +388,7 @@ def forward( channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim ctx.save_for_backward(x, scale) - return _log_norm(x, scale, channel_dim) + return _exp_norm(x, scale, channel_dim) return ans @@ -403,7 +403,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: scale.requires_grad = True with torch.enable_grad(): - ans = _log_norm(x, scale, ctx.channel_dim) + ans = _exp_norm(x, scale, ctx.channel_dim) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -417,7 +417,7 @@ def c(x): class BiasNorm(torch.nn.Module): """ - Comment not up-to-date. This is LogNorm. Will change docs later if + Comment not up-to-date. This is ExpNorm. Will change docs later if promising. This is intended to be a simpler, and hopefully cheaper, replacement for @@ -462,12 +462,12 @@ def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _log_norm(x, self.scale, self.channel_dim) + return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( self.scale, min=0.5, max=2.5, training=self.training) - ans = LogNormFunction.apply( + ans = ExpNormFunction.apply( x, scale, self.channel_dim, ) From b3883f68caef13f4e1d86757bece01a70bf38ff0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 21:37:37 +0800 Subject: [PATCH 0164/1191] Double in_proj scale and halve out_proj scale of feedforward; now 20, 0.025 --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 87b8da9eef..7c47751340 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1632,7 +1632,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=10.0) + initial_scale=20.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( @@ -1642,7 +1642,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): dropout_p=dropout, dropout_shared_dim=0, bias=True, - initial_scale=0.05, + initial_scale=0.025, ) self.out_whiten = Whiten( From bffec500dc94276c774048a4eb8ee84b99364f35 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 22:31:24 +0800 Subject: [PATCH 0165/1191] Make all output projections start off fairly small. --- egs/librispeech/ASR/zipformer/joiner.py | 6 +++--- egs/librispeech/ASR/zipformer/model.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index 0406efe834..dc82d2c2d4 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -29,9 +29,9 @@ def __init__( ): super().__init__() - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) - self.output_linear = nn.Linear(joiner_dim, vocab_size) + self.encoder_proj = nn.Linear(encoder_dim, joiner_dim) + self.decoder_proj = nn.Linear(decoder_dim, joiner_dim) + self.output_linear = ScaledLinear(joiner_dim, vocab_size, initial_scale=0.1) def forward( self, diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 3e17b9ebb9..5cd86b0972 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -99,10 +99,10 @@ def __init__( self.joiner = joiner self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_scale=0.25 + encoder_dim, vocab_size, initial_scale=0.1, ) self.simple_lm_proj = ScaledLinear( - decoder_dim, vocab_size, initial_scale=0.25 + decoder_dim, vocab_size, initial_scale=0.1, ) else: assert decoder is None @@ -113,7 +113,7 @@ def __init__( # Modules for CTC head self.ctc_output = nn.Sequential( nn.Dropout(p=0.1), - nn.Linear(encoder_dim, vocab_size), + ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), nn.LogSoftmax(dim=-1), ) @@ -123,8 +123,8 @@ def __init__( else: assert attention_decoder is None - self.reconstruction_proj = torch.nn.Linear( - encoder_dim, 4 * encoder_embed.in_channels) + self.reconstruction_proj = ScaledLinear( + encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) self.reconstruction_loss = torch.nn.SmoothL1Loss(reduction='none', beta=1.0) def forward_encoder( From 79a16857a5e760c4b65a2d80a3f6ce2fc2551cbd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 22:37:08 +0800 Subject: [PATCH 0166/1191] Revert joiner.py changes --- egs/librispeech/ASR/zipformer/joiner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index dc82d2c2d4..0406efe834 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -29,9 +29,9 @@ def __init__( ): super().__init__() - self.encoder_proj = nn.Linear(encoder_dim, joiner_dim) - self.decoder_proj = nn.Linear(decoder_dim, joiner_dim) - self.output_linear = ScaledLinear(joiner_dim, vocab_size, initial_scale=0.1) + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.output_linear = nn.Linear(joiner_dim, vocab_size) def forward( self, From b760e80bb6e83a0c70f0ca881111d8b8422f8e34 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Feb 2025 22:52:26 +0800 Subject: [PATCH 0167/1191] Revert initial scales of ff module to 5,0.1 --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7c47751340..4805f085a5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1632,7 +1632,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=20.0) + initial_scale=5.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( @@ -1642,7 +1642,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): dropout_p=dropout, dropout_shared_dim=0, bias=True, - initial_scale=0.025, + initial_scale=0.1, ) self.out_whiten = Whiten( From 5bdd1a338338d9fde2ea24b8f1df01ba488d6688 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Feb 2025 19:14:31 +0800 Subject: [PATCH 0168/1191] Implement DigitalSwooshL. --- egs/librispeech/ASR/zipformer/scaling.py | 23 ++++++++++++++++++++++ egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 234d2372cc..03a889e392 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1646,6 +1646,25 @@ def SwooshRForward(x: Tensor): return log_sum - 0.08 * x - 0.313261687 + +def digital_swooshl_forward(x): + # from wolfram alpha, comparing with swooshl: + #plot[ .25 * (log(1 + exp(4*x-4)) - .08*(4*x) - .035) ],[ -.08 * (x- -.06) +.04 * max(x- -.06, 0) +.15 * max(x-.5, 0) +.15 * max(x-.7, 0) +.25 * max(x-1, 0) +.25*max(x-1.2,0) ]for x=-1 to 2, + + x6 = x - -0.06 + return -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + + +def digital_swooshl_forward_and_deriv(x): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + x6 = x - -0.06 + y = -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + y.backward(gradient=torch.ones_like(y)) + return y, x.grad + + class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd @@ -1676,6 +1695,7 @@ def forward( forward_activation_dict = { "SwooshL": k2.swoosh_l_forward, "SwooshR": k2.swoosh_r_forward, + "DigitalSwooshL": digital_swooshl_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1695,6 +1715,7 @@ def backward(ctx, ans_grad: Tensor): forward_and_deriv_activation_dict = { "SwooshL": k2.swoosh_l_forward_and_deriv, "SwooshR": k2.swoosh_r_forward_and_deriv, + "DigitalSwooshL": digital_swooshl_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. @@ -1780,6 +1801,8 @@ def forward(self, x: Tensor): x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) + elif self.activation == "DigitalSwooshL": + x = digital_swooshl_forward(x) else: assert False, self.activation return torch.nn.functional.linear(x, self.weight, self.bias) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6beb5bd122..1c18ac72f8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1634,7 +1634,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): self.out_proj = ActivationDropoutAndLinear( feedforward_dim, embed_dim, - activation="SwooshL", + activation="DigitalSwooshL", dropout_p=dropout, dropout_shared_dim=0, bias=True, From 587ef615dc711f7370691b8db01708063ec186e5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Feb 2025 20:50:39 +0800 Subject: [PATCH 0169/1191] Change in_proj and out_proj scales of feedforward by factor of 5 --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 1c18ac72f8..ea0f48bf02 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1627,8 +1627,7 @@ class FeedforwardModule(nn.Module): def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. - self.in_proj = ScaledLinear(embed_dim, feedforward_dim, - initial_scale=5.0) + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( @@ -1638,7 +1637,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): dropout_p=dropout, dropout_shared_dim=0, bias=True, - initial_scale=0.1, + initial_scale=0.5, ) self.out_whiten = Whiten( From 535060dbe3f167f7eb53d1ce8c692a3d58395df1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Feb 2025 21:45:43 +0800 Subject: [PATCH 0170/1191] Change digitalswooshl formula to be more like swooshl. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 03a889e392..9b7cfb1775 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1652,7 +1652,7 @@ def digital_swooshl_forward(x): #plot[ .25 * (log(1 + exp(4*x-4)) - .08*(4*x) - .035) ],[ -.08 * (x- -.06) +.04 * max(x- -.06, 0) +.15 * max(x-.5, 0) +.15 * max(x-.7, 0) +.25 * max(x-1, 0) +.25*max(x-1.2,0) ]for x=-1 to 2, x6 = x - -0.06 - return -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + return -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + 0.16 * (x - 1.8).relu() def digital_swooshl_forward_and_deriv(x): @@ -1660,7 +1660,7 @@ def digital_swooshl_forward_and_deriv(x): x = x.detach() x.requires_grad = True x6 = x - -0.06 - y = -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + y = -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + 0.16 * (x - 1.8).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad From a32edf88c128665009d6087eafea86e16fafba1c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Feb 2025 14:35:00 +0800 Subject: [PATCH 0171/1191] Shift DigitalSwooshL horitontally so that leftmost kink is at x=0,y=0 --- egs/librispeech/ASR/zipformer/scaling.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9b7cfb1775..9ceda83bf8 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1651,16 +1651,18 @@ def digital_swooshl_forward(x): # from wolfram alpha, comparing with swooshl: #plot[ .25 * (log(1 + exp(4*x-4)) - .08*(4*x) - .035) ],[ -.08 * (x- -.06) +.04 * max(x- -.06, 0) +.15 * max(x-.5, 0) +.15 * max(x-.7, 0) +.25 * max(x-1, 0) +.25*max(x-1.2,0) ]for x=-1 to 2, - x6 = x - -0.06 - return -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + 0.16 * (x - 1.8).relu() + # the couple lines below are to shift the function so that the left-most discontinuity is at + # x=0. + _x = x + -0.06 + return -0.08 * x + 0.04 * x.relu() + 0.15 * (_x - 0.5).relu() + 0.15 * (_x - 0.7).relu() + 0.25 * (_x - 1.0).relu() + 0.25 * (_x - 1.2).relu() + 0.16 * (_x - 1.8).relu() def digital_swooshl_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - x6 = x - -0.06 - y = -0.08 * x6 + 0.04 * x6.relu() + 0.15 * (x - 0.5).relu() + 0.15 * (x - 0.7).relu() + 0.25 * (x - 1.0).relu() + 0.25 * (x - 1.2).relu() + 0.16 * (x - 1.8).relu() + _x = x + -0.06 + y = -0.08 * x + 0.04 * x.relu() + 0.15 * (_x - 0.5).relu() + 0.15 * (_x - 0.7).relu() + 0.25 * (_x - 1.0).relu() + 0.25 * (_x - 1.2).relu() + 0.16 * (_x - 1.8).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad From bf87e32ff6d88339faf2351576c030e87d001412 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Feb 2025 19:51:06 +0800 Subject: [PATCH 0172/1191] Change remaining SwooshL and SwooshR instances to DigitalSwooshL. --- egs/librispeech/ASR/zipformer/scaling.py | 21 ++++++++++++++++++++ egs/librispeech/ASR/zipformer/subsampling.py | 9 +++++---- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ea4939cb02..c7e524a131 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1652,6 +1652,27 @@ def digital_swooshl_forward_and_deriv(x): y.backward(gradient=torch.ones_like(y)) return y, x.grad +class DigitalSwooshLFunction(torch.nn.Module): + @staticmethod + def forward(ctx, x: Tensor): + ctx.save_for_backward(x) + return digital_swooshl_forward(x) + + def backward(ctx, y_grad: Tensor): + # this could be optimized, we could compute the derivative directly rather than use backward(). + x, = ctx.saved_tensors + y, function_deriv = digital_swooshl_forward_and_deriv(x) + return y_grad * function_deriv + +class DigitalSwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Digital Swoosh-L activation.""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return digital_swooshl_forward(x) + return DigitalSwooshLFunction.apply(x) + + + class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index a64b74df1c..54a90ae19c 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -32,6 +32,7 @@ ScaleGrad, ScheduledFloat, SwooshL, + DigitalSwooshL, SwooshR, Whiten, ) @@ -69,7 +70,7 @@ def __init__( in_channels=channels, out_channels=hidden_channels, kernel_size=1 ) - self.activation = SwooshL() + self.activation = DigitalSwooshL() self.pointwise_conv2 = ScaledConv2d( in_channels=hidden_channels, out_channels=channels, @@ -225,7 +226,7 @@ def __init__( padding=(0, 1), # (time, freq) ), ScaleGrad(0.2), - SwooshR(), + DigitalSwooshL(), nn.Conv2d( in_channels=layer1_channels, out_channels=layer2_channels, @@ -233,14 +234,14 @@ def __init__( stride=2, padding=0, ), - SwooshR(), + DigitalSwooshL(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, stride=(1, 2), # (time, freq) ), - SwooshR(), + DigitalSwooshL(), ) # just one convnext layer diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 858e71fb0a..fc16e5fba1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1874,7 +1874,7 @@ def __init__( self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, - activation="SwooshR", + activation="DigitalSwooshL", dropout_p=0.0, initial_scale=0.05, ) From 5f6158e4396f278ba555ebbda4b44b22930eeed2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Feb 2025 22:38:15 +0800 Subject: [PATCH 0173/1191] bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c7e524a131..2451370393 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1652,12 +1652,13 @@ def digital_swooshl_forward_and_deriv(x): y.backward(gradient=torch.ones_like(y)) return y, x.grad -class DigitalSwooshLFunction(torch.nn.Module): +class DigitalSwooshLFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor): ctx.save_for_backward(x) return digital_swooshl_forward(x) + @staticmethod def backward(ctx, y_grad: Tensor): # this could be optimized, we could compute the derivative directly rather than use backward(). x, = ctx.saved_tensors From 77cd8f982c78c4488582be8e2ad4d607b285a617 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 12:48:49 +0800 Subject: [PATCH 0174/1191] Change digital swoosh to a centered swoosh that is locally like leakyrelu(negative_slope=0.12) --- egs/librispeech/ASR/zipformer/scaling.py | 25 +++++++++------------- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9ceda83bf8..ae997393d6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1647,22 +1647,17 @@ def SwooshRForward(x: Tensor): -def digital_swooshl_forward(x): - # from wolfram alpha, comparing with swooshl: - #plot[ .25 * (log(1 + exp(4*x-4)) - .08*(4*x) - .035) ],[ -.08 * (x- -.06) +.04 * max(x- -.06, 0) +.15 * max(x-.5, 0) +.15 * max(x-.7, 0) +.25 * max(x-1, 0) +.25*max(x-1.2,0) ]for x=-1 to 2, +def digital_swoosh_forward(x): + # from wolfram alpha, comparing with swoosh: + #plot[ .25 * (log(1 + exp(4*x-1.4)) - .08*(4*x) - .2) ], [.02*x + .15*max(x,0) + .2*max(x-.2, 0) + .25*max(x-.4, 0) + .25*max(x-.7, 0) + .1*max(-0.4-x, 0) ] for x=-2 to 2 + return .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() - # the couple lines below are to shift the function so that the left-most discontinuity is at - # x=0. - _x = x + -0.06 - return -0.08 * x + 0.04 * x.relu() + 0.15 * (_x - 0.5).relu() + 0.15 * (_x - 0.7).relu() + 0.25 * (_x - 1.0).relu() + 0.25 * (_x - 1.2).relu() + 0.16 * (_x - 1.8).relu() - -def digital_swooshl_forward_and_deriv(x): +def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - _x = x + -0.06 - y = -0.08 * x + 0.04 * x.relu() + 0.15 * (_x - 0.5).relu() + 0.15 * (_x - 0.7).relu() + 0.25 * (_x - 1.0).relu() + 0.25 * (_x - 1.2).relu() + 0.16 * (_x - 1.8).relu() + y = .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad @@ -1697,7 +1692,7 @@ def forward( forward_activation_dict = { "SwooshL": k2.swoosh_l_forward, "SwooshR": k2.swoosh_r_forward, - "DigitalSwooshL": digital_swooshl_forward, + "DigitalSwoosh": digital_swoosh_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1717,7 +1712,7 @@ def backward(ctx, ans_grad: Tensor): forward_and_deriv_activation_dict = { "SwooshL": k2.swoosh_l_forward_and_deriv, "SwooshR": k2.swoosh_r_forward_and_deriv, - "DigitalSwooshL": digital_swooshl_forward_and_deriv, + "DigitalSwoosh": digital_swoosh_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. @@ -1803,8 +1798,8 @@ def forward(self, x: Tensor): x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) - elif self.activation == "DigitalSwooshL": - x = digital_swooshl_forward(x) + elif self.activation == "DigitalSwoosh": + x = digital_swoosh_forward(x) else: assert False, self.activation return torch.nn.functional.linear(x, self.weight, self.bias) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ea0f48bf02..16f562eddb 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1633,7 +1633,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): self.out_proj = ActivationDropoutAndLinear( feedforward_dim, embed_dim, - activation="DigitalSwooshL", + activation="DigitalSwoosh", dropout_p=dropout, dropout_shared_dim=0, bias=True, @@ -1649,7 +1649,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): def forward(self, x: Tensor): x = self.in_proj(x) - # out_proj contains SwooshL activation, then dropout, then linear. + # out_proj contains DigitalSwoosh activation, then dropout, then linear. x = self.out_proj(x) x = self.out_whiten(x) return x From 8dc09144fa477720b7496fa32b1f52472be41c3b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 13:29:44 +0800 Subject: [PATCH 0175/1191] Revert scalar_lr_scale from .025 to .1 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 67f0691782..ce5ba69abd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -411,7 +411,7 @@ def __init__( lr=3e-02, clipping_scale=None, betas=(0.9, 0.98), - scalar_lr_scale=0.025, + scalar_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, weight_max_rms=1.0, From 4c3da3f694587c5ec4b3109c90ff6d24cadc8af6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 13:29:44 +0800 Subject: [PATCH 0176/1191] Revert scalar_lr_scale from .025 to .1 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 67f0691782..ce5ba69abd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -411,7 +411,7 @@ def __init__( lr=3e-02, clipping_scale=None, betas=(0.9, 0.98), - scalar_lr_scale=0.025, + scalar_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, weight_max_rms=1.0, From 0d8cb08c8230e923a9bdca993f5410ba65b0483e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 16:17:16 +0800 Subject: [PATCH 0177/1191] Change remaining Swoosh instances to DigitalSwoosh, in zipformer.py and subsampling.py --- egs/librispeech/ASR/zipformer/subsampling.py | 9 +++++---- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index e7f7f7cbea..3eca9388f4 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -32,6 +32,7 @@ ScaleGrad, ScheduledFloat, SwooshL, + DigitalSwoosh, SwooshR, Whiten, ) @@ -69,7 +70,7 @@ def __init__( in_channels=channels, out_channels=hidden_channels, kernel_size=1 ) - self.activation = SwooshL() + self.activation = DigitalSwoosh() self.pointwise_conv2 = ScaledConv2d( in_channels=hidden_channels, out_channels=channels, @@ -225,7 +226,7 @@ def __init__( padding=(0, 1), # (time, freq) ), ScaleGrad(0.2), - SwooshR(), + DigitalSwoosh(), nn.Conv2d( in_channels=layer1_channels, out_channels=layer2_channels, @@ -233,14 +234,14 @@ def __init__( stride=2, padding=0, ), - SwooshR(), + DigitalSwoosh(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, stride=(1, 2), # (time, freq) ), - SwooshR(), + DigitalSwoosh(), ) # just one convnext layer diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 16f562eddb..324a94bf19 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1870,7 +1870,7 @@ def __init__( self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, - activation="SwooshR", + activation="DigitalSwoosh", dropout_p=0.0, initial_scale=0.05, ) From 302c7f4ea25ae2384cd7d1198403f62e7fe20159 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 16:32:39 +0800 Subject: [PATCH 0178/1191] remove mistakenly merged changes --- egs/librispeech/ASR/zipformer/zipformer.py | 28 ++++++++++------------ 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 39d4316f3a..47436db964 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -509,19 +509,19 @@ def __init__( dropout=0.0, ) - self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) - - - self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - for _ in range(2) ] + self.conv_module = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) - self.scale_limiter = ScaleLimiter(max_scale=4.0) + self.scale_limiter = ScaleLimiter(max_scale=ScheduledFloat((0.0, 2.0), (10000.0, 0.5), default=2.0)) self.norm = BiasNorm(embed_dim) @@ -563,16 +563,12 @@ def forward( src = src + self.self_attn1(src, attn_weights) - src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) + src = src + self.conv_module( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ) src = src + self.feed_forward2(src) - src = src + self.self_attn2(src, attn_weights) - - src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) - - src = src + self.feed_forward3(src) - src = self.bypass(src_orig, src) src = self.scale_limiter(src) @@ -1874,7 +1870,7 @@ def __init__( self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, - activation="DigitalSwoosh", + activation="DigitalSwooshR", dropout_p=0.0, initial_scale=0.05, ) From 2f38dbb25c937acfa2f8374aa0ea33e3bde7b9b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 16:35:29 +0800 Subject: [PATCH 0179/1191] Fix bug in write_debug_info --- egs/librispeech/ASR/zipformer/optim.py | 3 +-- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ce5ba69abd..996557238d 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -316,10 +316,9 @@ def _write_debug_info(group, state, param_names, summary_writer): """ Writes to a Tensorboard, model-debugging information that was accumulated in debug_step. """ - cur_step = state["step"] debug_interval = group["debug_interval"] - try: + cur_step = state["step"] debug_info = state["debug_info_cpu"] except KeyError: return diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 47436db964..324a94bf19 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1870,7 +1870,7 @@ def __init__( self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, - activation="DigitalSwooshR", + activation="DigitalSwoosh", dropout_p=0.0, initial_scale=0.05, ) From ccd6e302c8844ba71c780775214df0ac44ddf746 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 22:47:30 +0800 Subject: [PATCH 0180/1191] Remove bypass-noise and add whitening, in ZipformerEncoder --- egs/librispeech/ASR/zipformer/zipformer.py | 38 ++++++---------------- 1 file changed, 10 insertions(+), 28 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 324a94bf19..fd9a888b68 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -31,7 +31,6 @@ OrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, - ScaleLimiter, BiasNorm, ChunkCausalDepthwiseConv1d, Dropout2, @@ -521,8 +520,6 @@ def __init__( embed_dim, cnn_module_kernel, causal=causal ) - self.scale_limiter = ScaleLimiter(max_scale=ScheduledFloat((0.0, 2.0), (10000.0, 0.5), default=2.0)) - self.norm = BiasNorm(embed_dim) @@ -571,8 +568,6 @@ def forward( src = self.bypass(src_orig, src) - src = self.scale_limiter(src) - return self.norm(src) def streaming_forward( @@ -717,7 +712,6 @@ def __init__( num_layers: int, pos_dim: int, dropout: float, - bypass_noise: FloatLike = ScheduledFloat((0.0, 0.0), (4000.0, 0.2), (8000.0, 0.0)), ) -> None: super().__init__() self.encoder_pos = CompactRelPositionalEncoding( @@ -728,7 +722,14 @@ def __init__( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers - self.bypass_noise = copy.deepcopy(bypass_noise) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + def forward( self, @@ -759,8 +760,6 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] - if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): - bypass = self._add_noise_to_bypass(bypass) for i, mod in enumerate(self.layers): src = mod( @@ -773,30 +772,13 @@ def forward( # randomize_factor can be viewed as a simple version of an # importance-sampling factor. + src = self.whiten(src) + if num_channels > layer_dim: src = torch.cat((src, bypass), dim=-1) return src - def _add_noise_to_bypass(self, x: Tensor): - bypass_scale = float(self.bypass_noise) - # a simpler way to set the noise scale would be to use - # bypass_scale * (x ** 2).mean().sqrt(). Using - # 0.5 * ((x ** 2).mean() + 1.0) instead gives the same answer when the rms - # is 1.0, and a larger answer elsewhere, so it encourages the rms of - # x to be about 1.0. Using .mean(dim=-1, keepdim=True) instead of .mean(), i.e. per-frame - # magnitude, helps to keep the gradients more concentrated which, in fp16 - # training, should reduce certain biases caused by roundoff which otherwise - # tend to lead the embeddings to get smaller in scale. - noise_scale = (0.5 * bypass_scale) * ((x ** 2).mean(dim=-1, keepdim=True) + 1.0) - - if random.random() < 0.001: - logging.info(f"name={self.name}, x_rms={(x**2).mean().sqrt().item()}, bypass_scale={bypass_scale}, noise_rms={noise_scale.mean()}") - - - return x + torch.randn_like(x) * noise_scale - - def streaming_forward( self, src: Tensor, From 6401d9ac155d112022095ff55fb190771635d434 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Feb 2025 22:03:51 +0800 Subject: [PATCH 0181/1191] Add 1.0e-05 to digitalswoosh. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 95ab0ad4af..7302d3f8c6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1636,14 +1636,14 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # from wolfram alpha, comparing with swoosh: #plot[ .25 * (log(1 + exp(4*x-1.4)) - .08*(4*x) - .2) ], [.02*x + .15*max(x,0) + .2*max(x-.2, 0) + .25*max(x-.4, 0) + .25*max(x-.7, 0) + .1*max(-0.4-x, 0) ] for x=-2 to 2 - return .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + return 1.0e-05 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + y = 1.0e-05 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad From ecf365ea30de4cc2caba4d0f5721b533f40474b4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Feb 2025 14:37:51 +0800 Subject: [PATCH 0182/1191] Increase offset of DigitalSwoosh from 1e-5 to 1e-4 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7302d3f8c6..41db87483a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1636,14 +1636,14 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # from wolfram alpha, comparing with swoosh: #plot[ .25 * (log(1 + exp(4*x-1.4)) - .08*(4*x) - .2) ], [.02*x + .15*max(x,0) + .2*max(x-.2, 0) + .25*max(x-.4, 0) + .25*max(x-.7, 0) + .1*max(-0.4-x, 0) ] for x=-2 to 2 - return 1.0e-05 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + return 1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = 1.0e-05 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + y = 1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad From 92e5a9ed830fad4facf7bf5f544161e52a0e4699 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Feb 2025 15:55:31 +0800 Subject: [PATCH 0183/1191] Start off reconstruction_loss_scale at twice the final value, warm up with warm step; halve final value. --- egs/librispeech/ASR/zipformer/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index aa72aaa3e9..8f471423b0 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -476,8 +476,8 @@ def get_parser(): parser.add_argument( "--reconstruction-loss-scale", type=float, - default=0.01, - help="Scale for log-mel reconstruction loss.", + default=0.005, + help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", ) parser.add_argument( @@ -984,7 +984,10 @@ def compute_loss( if use_cr_ctc: loss += params.cr_loss_scale * cr_loss - loss += params.reconstruction_loss_scale * reconstruction_loss + reconstruction_loss_scale = (params.reconstruction_loss_scale * + max(1.0, 2.0 - 1.0 * (batch_idx_train / warm_step))) + + loss += reconstruction_loss_scale * reconstruction_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss From 19d856f2893c8a211b8256338ca97060c541f3fc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Feb 2025 18:29:40 +0800 Subject: [PATCH 0184/1191] Make DigitalSwooshL offset of 1.0e-04 negative. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 41db87483a..34695a5dd5 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1636,14 +1636,14 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # from wolfram alpha, comparing with swoosh: #plot[ .25 * (log(1 + exp(4*x-1.4)) - .08*(4*x) - .2) ], [.02*x + .15*max(x,0) + .2*max(x-.2, 0) + .25*max(x-.4, 0) + .25*max(x-.7, 0) + .1*max(-0.4-x, 0) ] for x=-2 to 2 - return 1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + return -1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = 1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + y = -1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad From 67d3021511a059f74d8a4f0936c0c2cbe5b990bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Feb 2025 12:46:48 +0800 Subject: [PATCH 0185/1191] Add the extra modules in each layer, as in 188conv --- egs/librispeech/ASR/zipformer/zipformer.py | 24 +++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fd9a888b68..2788779f10 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -508,17 +508,17 @@ def __init__( dropout=0.0, ) - self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] + + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) - self.feed_forward1 = FeedforwardModule( - embed_dim, (feedforward_dim * 3) // 4, dropout - ) self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.conv_module = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) + + self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + for _ in range(2) ] self.norm = BiasNorm(embed_dim) @@ -560,12 +560,16 @@ def forward( src = src + self.self_attn1(src, attn_weights) - src = src + self.conv_module( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ) + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward2(src) + src = src + self.self_attn2(src, attn_weights) + + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward3(src) + src = self.bypass(src_orig, src) return self.norm(src) From b204a4f7a197f6f081221e8e92eede544d4e5e59 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Feb 2025 13:16:36 +0800 Subject: [PATCH 0186/1191] Implement self-similar digital Swoosh. --- egs/librispeech/ASR/zipformer/scaling.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 34695a5dd5..416c444e94 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1634,16 +1634,18 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): - # from wolfram alpha, comparing with swoosh: - #plot[ .25 * (log(1 + exp(4*x-1.4)) - .08*(4*x) - .2) ], [.02*x + .15*max(x,0) + .2*max(x-.2, 0) + .25*max(x-.4, 0) + .25*max(x-.7, 0) + .1*max(-0.4-x, 0) ] for x=-2 to 2 - return -1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + # self-similar digital swoosh: + # (i.e the zoom-in near x=0 is very similar to the far-zoomed-out version) + # type this into wolfram alpha to see the graph vs. the log-add function: + #plot[ .25 * (log(1 + exp(4*x-2.2)) - .08*(4*x) - .08) ], [-1e-04 -.01*x + .1*max(x,0) + .065*max(-.25-x,0) + .15*max(x-.25,0) + .25*max(x-.5,0) + .2*max(x-.75,0) + .2*max(x-1.,0)] for x = -2 to 2 + return -1.0e-04 + .01 * x + .1 * x.relu() + .065*(-.25-x).relu() + .15 * (x-.25).relu() + .25 * (x-.5).relu() + .2*(x-.75).relu() + .2 * (x-1).relu() def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = -1.0e-04 + .02 * x + .15 * x.relu() + .2 * (x-.2).relu() + .25 * (x-.4).relu() + .25 * (x-.7).relu() + .1 * (-0.4-x).relu() + y = -1.0e-04 + .01 * x + .1 * x.relu() + .065*(-.25-x).relu() + .15 * (x-.25).relu() + .25 * (x-.5).relu() + .2*(x-.75).relu() + .2 * (x-1).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad From b066348e2227180355416c3651f4d75d12e14c8f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Feb 2025 14:28:44 +0800 Subject: [PATCH 0187/1191] Change settings for debug-writer to not lose things at end --- egs/librispeech/ASR/zipformer/optim.py | 3 ++- egs/librispeech/ASR/zipformer/train.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 996557238d..3e8a443ede 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -341,7 +341,8 @@ def _write_debug_info(group, state, param_names, summary_writer): for name, info in zip(param_names, debug_info[..., i].unbind(dim=1)): debug_str = f"debug/{legend}/{name}" for step, value in zip(steps.tolist(), info.tolist()): - summary_writer.add_scalar(debug_str, value, step) + if step >= 0: + summary_writer.add_scalar(debug_str, value, step) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index aa72aaa3e9..f66a1634c1 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1295,7 +1295,7 @@ def run(rank, world_size, args): # "optimizer.write_debug_info(), we can continue with training and let the # background thread take care of dumping those events at its own speed. tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard", - max_queue=10000000) + max_queue=100) else: tb_writer = None From 8cded45dd6f06bbf5f63e08b3cb7fc2b02bdb0e7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Feb 2025 14:41:20 +0800 Subject: [PATCH 0188/1191] Restore scale_limiter --- egs/librispeech/ASR/zipformer/zipformer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2788779f10..8f99bdebaa 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -30,6 +30,7 @@ OrthogonalLinearSpecial, OrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + ScaleLimiter, ActivationDropoutAndLinear, BiasNorm, ChunkCausalDepthwiseConv1d, @@ -520,6 +521,8 @@ def __init__( self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) for _ in range(2) ] + self.scale_limiter = ScaleLimiter(max_scale=2.0) + self.norm = BiasNorm(embed_dim) @@ -572,6 +575,8 @@ def forward( src = self.bypass(src_orig, src) + src = self.scale_limiter(src) + return self.norm(src) def streaming_forward( From 1b3e31281e3984043de6c2d3738bae750196535b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Feb 2025 15:28:11 +0800 Subject: [PATCH 0189/1191] Simplify initialization in subsampling.py, remove layer-skipping and whitening. --- egs/librispeech/ASR/zipformer/subsampling.py | 42 ++------------------ 1 file changed, 3 insertions(+), 39 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 859f394b67..65cd03b71e 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -49,14 +49,10 @@ def __init__( channels: int, hidden_ratio: int = 3, kernel_size: Tuple[int, int] = (7, 7), - layerdrop_rate: FloatLike = None, ): super().__init__() self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) hidden_channels = channels * hidden_ratio - if layerdrop_rate is None: - layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) - self.layerdrop_rate = layerdrop_rate self.depthwise_conv = nn.Conv2d( in_channels=channels, @@ -72,39 +68,15 @@ def __init__( self.activation = DigitalSwoosh() - self.pointwise_conv2 = ScaledConv2d( + self.pointwise_conv2 = nn.Conv2d( in_channels=hidden_channels, out_channels=channels, kernel_size=1, - initial_scale=0.01, ) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.forward_internal(x) - layerdrop_rate = float(self.layerdrop_rate) - - if layerdrop_rate != 0.0: - batch_size = x.shape[0] - mask = ( - torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) - > layerdrop_rate - ) - else: - mask = None - # turns out this caching idea does not work with --world-size > 1 - # return caching_eval(self.forward_internal, x, mask) - return self.forward_internal(x, mask) - - def forward_internal( - self, x: Tensor, layer_skip_mask: Optional[Tensor] = None + def forward( + self, x: Tensor, ) -> Tensor: """ x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) @@ -117,16 +89,8 @@ def forward_internal( x = self.activation(x) x = self.pointwise_conv2(x) - if layer_skip_mask is not None: - x = x * layer_skip_mask - x = bypass + x - if x.requires_grad: - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) - return x def streaming_forward( From 87aadc7a7f3eb52acdc14834e36c8f81e0984369 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Feb 2025 16:56:50 +0800 Subject: [PATCH 0190/1191] Re-tuned self-similar digital swoosh. --- egs/librispeech/ASR/zipformer/scaling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 416c444e94..5c8b218f79 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1634,18 +1634,18 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): - # self-similar digital swoosh: + # re-tuned self-similar digital swoosh: # (i.e the zoom-in near x=0 is very similar to the far-zoomed-out version) # type this into wolfram alpha to see the graph vs. the log-add function: - #plot[ .25 * (log(1 + exp(4*x-2.2)) - .08*(4*x) - .08) ], [-1e-04 -.01*x + .1*max(x,0) + .065*max(-.25-x,0) + .15*max(x-.25,0) + .25*max(x-.5,0) + .2*max(x-.75,0) + .2*max(x-1.,0)] for x = -2 to 2 - return -1.0e-04 + .01 * x + .1 * x.relu() + .065*(-.25-x).relu() + .15 * (x-.25).relu() + .25 * (x-.5).relu() + .2*(x-.75).relu() + .2 * (x-1).relu() + #plot[ .25 * (log(1 + exp(4*x-2.2)) - .08*(4*x) - .08) ], [-1e-04 -.01*x + .075*max(x,0) + .07*max(-.25-x,0) + .05*max(x-.15,0) + .2*max(x-.3,0) + .25*max(x-.6,0) + .3*max(x-.9,0)] for x = -2 to 2 + return -1.0e-04 + .01 * x + .075 * x.relu() + .07*(-.25-x).relu() + .05 * (x-.15).relu() + .2 * (x-.3).relu() + .25*(x-.6).relu() + .3 * (x-.9).relu() def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = -1.0e-04 + .01 * x + .1 * x.relu() + .065*(-.25-x).relu() + .15 * (x-.25).relu() + .25 * (x-.5).relu() + .2*(x-.75).relu() + .2 * (x-1).relu() + y = -1.0e-04 + .01 * x + .075 * x.relu() + .07*(-.25-x).relu() + .05 * (x-.15).relu() + .2 * (x-.3).relu() + .25*(x-.6).relu() + .3 * (x-.9).relu() y.backward(gradient=torch.ones_like(y)) return y, x.grad From f03e97b977c5e7d4125b0405ee98f8c0441cadcf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 14:37:23 +0800 Subject: [PATCH 0191/1191] Power-based nonlinearity with power 1.7 --- egs/librispeech/ASR/zipformer/scaling.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5c8b218f79..fa81cbf799 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1634,18 +1634,22 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): - # re-tuned self-similar digital swoosh: - # (i.e the zoom-in near x=0 is very similar to the far-zoomed-out version) - # type this into wolfram alpha to see the graph vs. the log-add function: - #plot[ .25 * (log(1 + exp(4*x-2.2)) - .08*(4*x) - .08) ], [-1e-04 -.01*x + .075*max(x,0) + .07*max(-.25-x,0) + .05*max(x-.15,0) + .2*max(x-.3,0) + .25*max(x-.6,0) + .3*max(x-.9,0)] for x = -2 to 2 - return -1.0e-04 + .01 * x + .075 * x.relu() + .07*(-.25-x).relu() + .05 * (x-.15).relu() + .2 * (x-.3).relu() + .25*(x-.6).relu() + .3 * (x-.9).relu() + # power-based swooshy thing with power=1.7 + + power = 1.7 + x_abs = x.abs() + return torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) + + def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = -1.0e-04 + .01 * x + .075 * x.relu() + .07*(-.25-x).relu() + .05 * (x-.15).relu() + .2 * (x-.3).relu() + .25*(x-.6).relu() + .3 * (x-.9).relu() + power = 1.7 + x_abs = x.abs() + y = torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) y.backward(gradient=torch.ones_like(y)) return y, x.grad From 2fc8066e27daf92687fe047ff9fe4bb770650c7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 14:42:57 +0800 Subject: [PATCH 0192/1191] Add offset -1.0e-03 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index fa81cbf799..52d5d99352 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1638,7 +1638,7 @@ def digital_swoosh_forward(x): power = 1.7 x_abs = x.abs() - return torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) + return -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) @@ -1649,7 +1649,7 @@ def digital_swoosh_forward_and_deriv(x): x.requires_grad = True power = 1.7 x_abs = x.abs() - y = torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) + y = -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) y.backward(gradient=torch.ones_like(y)) return y, x.grad From 803ab87a20648883862fbf75d9ace65467c4650c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 14:58:36 +0800 Subject: [PATCH 0193/1191] Change power from 1.7 to 1.5. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 52d5d99352..59d9221c6a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1636,7 +1636,7 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # power-based swooshy thing with power=1.7 - power = 1.7 + power = 1.5 x_abs = x.abs() return -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) @@ -1647,7 +1647,7 @@ def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - power = 1.7 + power = 1.5 x_abs = x.abs() y = -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) y.backward(gradient=torch.ones_like(y)) From 999d1db1303495799bbf65172384cf9ce3631bdf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 15:21:16 +0800 Subject: [PATCH 0194/1191] Change power in power-swoosh from 1.5 to 2.0 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 59d9221c6a..b0966ba0f3 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1636,7 +1636,7 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # power-based swooshy thing with power=1.7 - power = 1.5 + power = 2.0 x_abs = x.abs() return -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) @@ -1647,7 +1647,7 @@ def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - power = 1.5 + power = 2.0 x_abs = x.abs() y = -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) y.backward(gradient=torch.ones_like(y)) From 3e0a9035e15f725808e7241393941f53c137c596 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 17:39:23 +0800 Subject: [PATCH 0195/1191] Remove offset of -1.0e-03 in power-swoosh --- egs/librispeech/ASR/zipformer/scaling.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b0966ba0f3..287ec4e60b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,10 +1635,9 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # power-based swooshy thing with power=1.7 - power = 2.0 x_abs = x.abs() - return -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) + return torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) @@ -1649,7 +1648,7 @@ def digital_swoosh_forward_and_deriv(x): x.requires_grad = True power = 2.0 x_abs = x.abs() - y = -1.0e-03 + torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) + y = torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) y.backward(gradient=torch.ones_like(y)) return y, x.grad From 2de393658764d58a5ca0d4dcad6e0231e3465a9d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 17:41:40 +0800 Subject: [PATCH 0196/1191] Increase power from 2.0 to 2.2. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 287ec4e60b..adbdfbef60 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,7 +1635,7 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # power-based swooshy thing with power=1.7 - power = 2.0 + power = 2.2 x_abs = x.abs() return torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) @@ -1646,7 +1646,7 @@ def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - power = 2.0 + power = 2.2 x_abs = x.abs() y = torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) y.backward(gradient=torch.ones_like(y)) From a27b9bc144caae4ae8ff96d1480091b27dd4aa25 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 17:50:35 +0800 Subject: [PATCH 0197/1191] Change power to 2.1 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index adbdfbef60..591ce58137 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,7 +1635,7 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): # power-based swooshy thing with power=1.7 - power = 2.2 + power = 2.1 x_abs = x.abs() return torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) @@ -1646,7 +1646,7 @@ def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - power = 2.2 + power = 2.1 x_abs = x.abs() y = torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) y.backward(gradient=torch.ones_like(y)) From d77e31315bda6d26c5f0fd2f0ea0e80a43f215f3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 19:53:21 +0800 Subject: [PATCH 0198/1191] Code changes, preparing to change positive cutoff to 0.5 but just debugging for now --- egs/librispeech/ASR/zipformer/scaling.py | 29 +++++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 591ce58137..298740771a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1634,11 +1634,30 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): - # power-based swooshy thing with power=1.7 - power = 2.1 + pos_power = 2.1 + pos_cutoff = 1.0 # x cutoff where it becomes linear for x>0 + + neg_power = 2.1 + neg_cutoff = 1.0 + neg_coeff = 0.1 + x_abs = x.abs() - return torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) + pos_cutoff_y = pos_cutoff ** pos_power + pos_cutoff_dy_dx = pos_power * (pos_cutoff ** (pos_power - 1)) + pos_cutoff_offset = pos_cutoff_y - (pos_cutoff_dy_dx * pos_cutoff) + + neg_cutoff_y = neg_coeff * (neg_cutoff ** neg_power) + neg_cutoff_dy_dx = neg_coeff * neg_power * (neg_cutoff ** (neg_power - 1)) + neg_cutoff_offset = neg_cutoff_y - (neg_cutoff_dy_dx - neg_cutoff) + + y_pos = torch.where(x_abs > pos_cutoff, + x_abs ** pos_power, + x_abs * pos_cutoff_dy_dx + pos_cutoff_offset) + y_neg = torch.where(x_abs > neg_cutoff, + x_abs ** neg_power, + x_abs * neg_cutoff_dy_dx + neg_cutoff_offset) + return torch.where(x > 0, y_pos, y_neg) @@ -1646,9 +1665,7 @@ def digital_swoosh_forward_and_deriv(x): with torch.enable_grad(): x = x.detach() x.requires_grad = True - power = 2.1 - x_abs = x.abs() - y = torch.where(x_abs < 1, x_abs ** power, power * x_abs + (1 - power)) * torch.where(x > 0, 1.0, 0.1) + y = digital_swoosh_forward(x) y.backward(gradient=torch.ones_like(y)) return y, x.grad From 492fabea55ca709565bf08ee8cdde06afdf955a1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 19:56:56 +0800 Subject: [PATCH 0199/1191] add try/exceot in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3e8a443ede..a08f4b4ff8 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -639,7 +639,10 @@ def write_debug_info(self, summary_writer): with self.batched_params(group["params"], group_params_names) as batches: for _p, state, names in batches: _write_debug_info(group, state, names, summary_writer) - del state["debug_info_cpu"] + try: + del state["debug_info_cpu"] + except: + pass def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] From f181dad8c7023d8b678838144f62a58640b6de83 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 19:58:07 +0800 Subject: [PATCH 0200/1191] Bug fix in formula for y_neg --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 298740771a..2189add583 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1651,11 +1651,11 @@ def digital_swoosh_forward(x): neg_cutoff_dy_dx = neg_coeff * neg_power * (neg_cutoff ** (neg_power - 1)) neg_cutoff_offset = neg_cutoff_y - (neg_cutoff_dy_dx - neg_cutoff) - y_pos = torch.where(x_abs > pos_cutoff, + y_pos = torch.where(x_abs < pos_cutoff, x_abs ** pos_power, x_abs * pos_cutoff_dy_dx + pos_cutoff_offset) - y_neg = torch.where(x_abs > neg_cutoff, - x_abs ** neg_power, + y_neg = torch.where(x_abs < neg_cutoff, + neg_coeff * (x_abs ** neg_power), x_abs * neg_cutoff_dy_dx + neg_cutoff_offset) return torch.where(x > 0, y_pos, y_neg) From ad54d62d120c497b9cb07e5c608929104475fcd7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Feb 2025 21:18:25 +0800 Subject: [PATCH 0201/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2189add583..3e333fe2c3 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1649,7 +1649,7 @@ def digital_swoosh_forward(x): neg_cutoff_y = neg_coeff * (neg_cutoff ** neg_power) neg_cutoff_dy_dx = neg_coeff * neg_power * (neg_cutoff ** (neg_power - 1)) - neg_cutoff_offset = neg_cutoff_y - (neg_cutoff_dy_dx - neg_cutoff) + neg_cutoff_offset = neg_cutoff_y - (neg_cutoff_dy_dx * neg_cutoff) y_pos = torch.where(x_abs < pos_cutoff, x_abs ** pos_power, From b6c677efc5f344143ba0e527be82240129539bac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 14:00:10 +0800 Subject: [PATCH 0202/1191] Decrease neg_cutoff from 1.0 to 0.75 and increase neg_coeff to 0.1333 (.1 * 4/3). --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3e333fe2c3..4569e99200 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1638,8 +1638,8 @@ def digital_swoosh_forward(x): pos_cutoff = 1.0 # x cutoff where it becomes linear for x>0 neg_power = 2.1 - neg_cutoff = 1.0 - neg_coeff = 0.1 + neg_cutoff = 0.75 + neg_coeff = 0.1333 x_abs = x.abs() From 19d3b46fb484dbcd319218f2cc9ed4dd95f6c010 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 15:09:40 +0800 Subject: [PATCH 0203/1191] Change to neg_cutoff=1.0, neg_coeff=0.08 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 4569e99200..eda1510ffc 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1638,8 +1638,8 @@ def digital_swoosh_forward(x): pos_cutoff = 1.0 # x cutoff where it becomes linear for x>0 neg_power = 2.1 - neg_cutoff = 0.75 - neg_coeff = 0.1333 + neg_cutoff = 1.0 + neg_coeff = 0.08 x_abs = x.abs() From 5448187c967a3a9b457edb01d5aaa214926cfaf5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 17:16:04 +0800 Subject: [PATCH 0204/1191] Introduce neg_power2=0.75, change neg_power and pos_power from 2.1 to 2.0, neg_coeff to 0.1 --- egs/librispeech/ASR/zipformer/scaling.py | 34 ++++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index eda1510ffc..313c899adc 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1634,29 +1634,29 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): - pos_power = 2.1 - pos_cutoff = 1.0 # x cutoff where it becomes linear for x>0 + pos_power1 = 2.0 + pos_power2 = 1.0 - neg_power = 2.1 - neg_cutoff = 1.0 - neg_coeff = 0.08 + + neg_power1 = 2.0 + neg_power2 = 0.75 + + neg_coeff = 0.1 x_abs = x.abs() - pos_cutoff_y = pos_cutoff ** pos_power - pos_cutoff_dy_dx = pos_power * (pos_cutoff ** (pos_power - 1)) - pos_cutoff_offset = pos_cutoff_y - (pos_cutoff_dy_dx * pos_cutoff) + pos_power2_coeff = pos_power1 / pos_power2 + pos_offset = 1 - pos_power2_coeff - neg_cutoff_y = neg_coeff * (neg_cutoff ** neg_power) - neg_cutoff_dy_dx = neg_coeff * neg_power * (neg_cutoff ** (neg_power - 1)) - neg_cutoff_offset = neg_cutoff_y - (neg_cutoff_dy_dx * neg_cutoff) + neg_power2_coeff = neg_power1 / neg_power2 + neg_offset = 1 - neg_power2_coeff - y_pos = torch.where(x_abs < pos_cutoff, - x_abs ** pos_power, - x_abs * pos_cutoff_dy_dx + pos_cutoff_offset) - y_neg = torch.where(x_abs < neg_cutoff, - neg_coeff * (x_abs ** neg_power), - x_abs * neg_cutoff_dy_dx + neg_cutoff_offset) + y_pos = torch.where(x_abs < 1, + x_abs ** pos_power1, + (x_abs ** pos_power2) * pos_power2_coeff + pos_offset) + y_neg = torch.where(x_abs < 1, + x_abs ** neg_power1, + (x_abs ** neg_power2) * neg_power2_coeff + neg_offset) * neg_coeff return torch.where(x > 0, y_pos, y_neg) From 8f8a7fb23bd784d6b20c850a5e7f22c9dccbb47f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 17:31:13 +0800 Subject: [PATCH 0205/1191] Try to get rid of infinities in backprop --- egs/librispeech/ASR/zipformer/scaling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 313c899adc..ac2933a636 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1644,6 +1644,7 @@ def digital_swoosh_forward(x): neg_coeff = 0.1 x_abs = x.abs() + x_abs_clamp = x_abs.clamp(min=1.) # trying avoid inf*0=nan in backprop. pos_power2_coeff = pos_power1 / pos_power2 pos_offset = 1 - pos_power2_coeff @@ -1653,10 +1654,10 @@ def digital_swoosh_forward(x): y_pos = torch.where(x_abs < 1, x_abs ** pos_power1, - (x_abs ** pos_power2) * pos_power2_coeff + pos_offset) + (x_abs_clamp ** pos_power2) * pos_power2_coeff + pos_offset) y_neg = torch.where(x_abs < 1, x_abs ** neg_power1, - (x_abs ** neg_power2) * neg_power2_coeff + neg_offset) * neg_coeff + (x_abs_clamp ** neg_power2) * neg_power2_coeff + neg_offset) * neg_coeff return torch.where(x > 0, y_pos, y_neg) From bb3f5a1d11e22398e10471e9fecc231df588b2ef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 20:30:33 +0800 Subject: [PATCH 0206/1191] neg_power2=1.0, pos_power2=1.3 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ac2933a636..bd55e7172a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,11 +1635,11 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): pos_power1 = 2.0 - pos_power2 = 1.0 + pos_power2 = 1.3 neg_power1 = 2.0 - neg_power2 = 0.75 + neg_power2 = 1.0 neg_coeff = 0.1 From bcd2dd90ca56f0c9d4464ad05d6cd660c67d0af7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 20:38:24 +0800 Subject: [PATCH 0207/1191] Decrease pos_power to 0.8 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index bd55e7172a..5c2c9ce0cf 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,7 +1635,7 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): pos_power1 = 2.0 - pos_power2 = 1.3 + pos_power2 = 0.8 neg_power1 = 2.0 From cc1ec44f358f528eca515337af5fcf0f92977eb0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 22:46:56 +0800 Subject: [PATCH 0208/1191] Have negative linear term of -0.03 --- egs/librispeech/ASR/zipformer/scaling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5c2c9ce0cf..96db054ec2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,13 +1635,13 @@ def SwooshRForward(x: Tensor): def digital_swoosh_forward(x): pos_power1 = 2.0 - pos_power2 = 0.8 - + pos_power2 = 1.0 neg_power1 = 2.0 neg_power2 = 1.0 - neg_coeff = 0.1 + neg_coeff = 0.13 + linear_coeff = -0.03 x_abs = x.abs() x_abs_clamp = x_abs.clamp(min=1.) # trying avoid inf*0=nan in backprop. @@ -1658,7 +1658,7 @@ def digital_swoosh_forward(x): y_neg = torch.where(x_abs < 1, x_abs ** neg_power1, (x_abs_clamp ** neg_power2) * neg_power2_coeff + neg_offset) * neg_coeff - return torch.where(x > 0, y_pos, y_neg) + return x * linear_coeff + torch.where(x > 0, y_pos, y_neg) From 5f9afdadd3f14bf76697f62a5d0c3615f93325d8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Feb 2025 23:24:15 +0800 Subject: [PATCH 0209/1191] add .01 * x.relu() --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 96db054ec2..a3dad7e371 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1640,8 +1640,7 @@ def digital_swoosh_forward(x): neg_power1 = 2.0 neg_power2 = 1.0 - neg_coeff = 0.13 - linear_coeff = -0.03 + neg_coeff = 0.1 x_abs = x.abs() x_abs_clamp = x_abs.clamp(min=1.) # trying avoid inf*0=nan in backprop. @@ -1658,7 +1657,8 @@ def digital_swoosh_forward(x): y_neg = torch.where(x_abs < 1, x_abs ** neg_power1, (x_abs_clamp ** neg_power2) * neg_power2_coeff + neg_offset) * neg_coeff - return x * linear_coeff + torch.where(x > 0, y_pos, y_neg) + # add a little nonlinearity at origin: .01 * x.relu() + return torch.where(x > 0, y_pos, y_neg) + .01 * x.relu() From 7c73af60d7d00aa832b09b0c14287c55bce4fe29 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 26 Feb 2025 10:55:44 +0800 Subject: [PATCH 0210/1191] Change activations back to SwooshL and SwooshR, and modify initialization to match. --- egs/librispeech/ASR/zipformer/subsampling.py | 15 ++++++++------- egs/librispeech/ASR/zipformer/zipformer.py | 8 ++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 65cd03b71e..969cb0f167 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -62,16 +62,17 @@ def __init__( padding=self.padding, ) - self.pointwise_conv1 = nn.Conv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1 + self.pointwise_conv1 = ScaledConv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1, initial_scale=4.0 ) - self.activation = DigitalSwoosh() + self.activation = SwooshL() - self.pointwise_conv2 = nn.Conv2d( + self.pointwise_conv2 = ScaledConv2d( in_channels=hidden_channels, out_channels=channels, kernel_size=1, + initial_scale=0.25, ) @@ -191,7 +192,7 @@ def __init__( padding=(0, 1), # (time, freq) ), ScaleGrad(0.2), - DigitalSwoosh(), + SwooshR(), nn.Conv2d( in_channels=layer1_channels, out_channels=layer2_channels, @@ -199,14 +200,14 @@ def __init__( stride=2, padding=0, ), - DigitalSwoosh(), + SwooshR(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, stride=(1, 2), # (time, freq) ), - DigitalSwoosh(), + SwooshR(), ) # just one convnext layer diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8f99bdebaa..bf4504f466 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1618,17 +1618,17 @@ class FeedforwardModule(nn.Module): def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. - self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + self.in_proj = ScaledLinear(embed_dim, feedforward_dim, initial_scale=5.0) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( feedforward_dim, embed_dim, - activation="DigitalSwoosh", + activation="SwooshL", dropout_p=dropout, dropout_shared_dim=0, bias=True, - initial_scale=0.5, + initial_scale=0.1, ) self.out_whiten = Whiten( @@ -1861,7 +1861,7 @@ def __init__( self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, - activation="DigitalSwoosh", + activation="SwooshR", dropout_p=0.0, initial_scale=0.05, ) From b5291695f7837c1c4850ec00ed4c75918408988e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 26 Feb 2025 16:17:01 +0800 Subject: [PATCH 0211/1191] Reduce scalar_lr_scale from .1 to .05 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index a08f4b4ff8..a419d6091a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -411,7 +411,7 @@ def __init__( lr=3e-02, clipping_scale=None, betas=(0.9, 0.98), - scalar_lr_scale=0.1, + scalar_lr_scale=0.05, eps=1.0e-08, weight_min_rms=0.005, weight_max_rms=1.0, From a342e3febbadb1a5d7d844475adfcb64d3c0a820 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 27 Feb 2025 19:51:26 +0800 Subject: [PATCH 0212/1191] Introduce scales of 4 and 0.25 before/after swoosh; change initialization scales to compensate. --- egs/librispeech/ASR/zipformer/scaling.py | 197 ++----------------- egs/librispeech/ASR/zipformer/subsampling.py | 8 +- egs/librispeech/ASR/zipformer/zipformer.py | 5 +- 3 files changed, 17 insertions(+), 193 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a3dad7e371..1796561513 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1464,134 +1464,25 @@ def forward(self, x: Tensor) -> Tensor: -class SwooshLFunction(torch.autograd.Function): - """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16 or x.dtype == torch.bfloat16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - coeff = -0.08 - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 - - if not requires_grad: - return y - - y.backward(gradient=torch.ones_like(y)) - - grad = x.grad - floor = coeff - ceil = 1.0 + coeff + 0.005 - - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.get_autocast_gpu_dtype()) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - - coeff = -0.08 - floor = coeff - ceil = 1.0 + coeff + 0.005 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - class SwooshL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return Swoosh-L activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + return 0.25 * logaddexp(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 if not x.requires_grad: - return k2.swoosh_l_forward(x) + return 0.25 * k2.swoosh_l_forward(x * 4) else: - return k2.swoosh_l(x) - # return SwooshLFunction.apply(x) + return 0.25 * k2.swoosh_l(x * 4) class SwooshLOnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return Swoosh-L activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 - - -class SwooshRFunction(torch.autograd.Function): - """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - - derivatives are between -0.08 and 0.92. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - - if x.dtype == torch.float16 or x.dtype == torch.bfloat16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + return 0.25 * logaddexp_onnx(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 - if not requires_grad: - return y - y.backward(gradient=torch.ones_like(y)) - - grad = x.grad - floor = -0.08 - ceil = 0.925 - - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.get_autocast_gpu_dtype()) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.08 - ceil = 0.925 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d class SwooshR(torch.nn.Module): @@ -1599,96 +1490,36 @@ def forward(self, x: Tensor) -> Tensor: """Return Swoosh-R activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + return 0.25 * logaddexp(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 if not x.requires_grad: - return k2.swoosh_r_forward(x) + return 0.25 * k2.swoosh_r_forward(4 * x) else: - return k2.swoosh_r(x) - # return SwooshRFunction.apply(x) + return 0.25 * k2.swoosh_r(4 * x) class SwooshROnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return Swoosh-R activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 + return 0.25 * logaddexp_onnx(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 # simple version of SwooshL that does not redefine the backprop, used in # ActivationDropoutAndLinearFunction. def SwooshLForward(x: Tensor): - x_offset = x - 4.0 + x_offset = 4 * x - 4.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.035 + return 0.25 * log_sum - 0.08 * x - 0.00875 # simple version of SwooshR that does not redefine the backprop, used in # ActivationDropoutAndLinearFunction. def SwooshRForward(x: Tensor): - x_offset = x - 1.0 + x_offset = 4 * x - 1.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.313261687 - - - -def digital_swoosh_forward(x): - pos_power1 = 2.0 - pos_power2 = 1.0 - - neg_power1 = 2.0 - neg_power2 = 1.0 - - neg_coeff = 0.1 - - x_abs = x.abs() - x_abs_clamp = x_abs.clamp(min=1.) # trying avoid inf*0=nan in backprop. - - pos_power2_coeff = pos_power1 / pos_power2 - pos_offset = 1 - pos_power2_coeff - - neg_power2_coeff = neg_power1 / neg_power2 - neg_offset = 1 - neg_power2_coeff - - y_pos = torch.where(x_abs < 1, - x_abs ** pos_power1, - (x_abs_clamp ** pos_power2) * pos_power2_coeff + pos_offset) - y_neg = torch.where(x_abs < 1, - x_abs ** neg_power1, - (x_abs_clamp ** neg_power2) * neg_power2_coeff + neg_offset) * neg_coeff - # add a little nonlinearity at origin: .01 * x.relu() - return torch.where(x > 0, y_pos, y_neg) + .01 * x.relu() - - - -def digital_swoosh_forward_and_deriv(x): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = digital_swoosh_forward(x) - y.backward(gradient=torch.ones_like(y)) - return y, x.grad - -class DigitalSwooshFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor): - ctx.save_for_backward(x) - return digital_swoosh_forward(x) - - @staticmethod - def backward(ctx, y_grad: Tensor): - # this could be optimized, we could compute the derivative directly rather than use backward(). - x, = ctx.saved_tensors - y, function_deriv = digital_swoosh_forward_and_deriv(x) - return y_grad * function_deriv - -class DigitalSwoosh(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Digital Swoosh-L activation.""" - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return digital_swoosh_forward(x) - return DigitalSwooshFunction.apply(x) + return 0.25 * log_sum - 0.08 * x - 0.07831542175 @@ -1723,7 +1554,6 @@ def forward( forward_activation_dict = { "SwooshL": k2.swoosh_l_forward, "SwooshR": k2.swoosh_r_forward, - "DigitalSwoosh": digital_swoosh_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1743,7 +1573,6 @@ def backward(ctx, ans_grad: Tensor): forward_and_deriv_activation_dict = { "SwooshL": k2.swoosh_l_forward_and_deriv, "SwooshR": k2.swoosh_r_forward_and_deriv, - "DigitalSwoosh": digital_swoosh_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. @@ -1829,8 +1658,6 @@ def forward(self, x: Tensor): x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) - elif self.activation == "DigitalSwoosh": - x = digital_swoosh_forward(x) else: assert False, self.activation return torch.nn.functional.linear(x, self.weight, self.bias) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 969cb0f167..b007455f1b 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -32,7 +32,6 @@ ScaleGrad, ScheduledFloat, SwooshL, - DigitalSwoosh, SwooshR, Whiten, ) @@ -62,17 +61,16 @@ def __init__( padding=self.padding, ) - self.pointwise_conv1 = ScaledConv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1, initial_scale=4.0 + self.pointwise_conv1 = nn.Conv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1, ) self.activation = SwooshL() - self.pointwise_conv2 = ScaledConv2d( + self.pointwise_conv2 = nn.Conv2d( in_channels=hidden_channels, out_channels=channels, kernel_size=1, - initial_scale=0.25, ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bf4504f466..213de795db 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1618,7 +1618,7 @@ class FeedforwardModule(nn.Module): def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. - self.in_proj = ScaledLinear(embed_dim, feedforward_dim, initial_scale=5.0) + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( @@ -1628,7 +1628,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): dropout_p=dropout, dropout_shared_dim=0, bias=True, - initial_scale=0.1, + initial_scale=0.5, ) self.out_whiten = Whiten( @@ -1640,7 +1640,6 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): def forward(self, x: Tensor): x = self.in_proj(x) - # out_proj contains DigitalSwoosh activation, then dropout, then linear. x = self.out_proj(x) x = self.out_whiten(x) return x From 804621100eaf0119a605a0d3d1cbd3c94a306a6f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 28 Feb 2025 14:24:29 +0800 Subject: [PATCH 0213/1191] Fix memory leak --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 0d5cdbc216..7308b6884a 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1008,7 +1008,7 @@ def compute_loss( info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.use_cr_ctc: info["cr_loss"] = cr_loss.detach().cpu().item() - info["recon_loss"] = reconstruction_loss + info["recon_loss"] = reconstruction_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() From 43a789c0c8fa913cbd85fb424481d8fd17e043f8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 1 Mar 2025 19:24:44 +0800 Subject: [PATCH 0214/1191] Fix bugs RE scaled swoosh functions in scaling.py --- egs/librispeech/ASR/zipformer/scaling.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1796561513..2ed22303b6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1464,6 +1464,18 @@ def forward(self, x: Tensor) -> Tensor: +def _swoosh_l_forward_wrapper(x): + return 0.25 * k2.swoosh_l_forward(x * 4) +def _swoosh_r_forward_wrapper(x): + return 0.25 * k2.swoosh_r_forward(x * 4) +def _swoosh_l_forward_and_deriv_wrapper(x): + y, dy_dx = k2.swoosh_l_forward_and_deriv(x * 4) + return 0.25 * y, dy_dx +def _swoosh_r_forward_and_deriv_wrapper(x): + y, dy_dx = k2.swoosh_r_forward_and_deriv(x * 4) + return 0.25 * y, dy_dx + + class SwooshL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: @@ -1472,7 +1484,7 @@ def forward(self, x: Tensor) -> Tensor: zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return 0.25 * logaddexp(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 if not x.requires_grad: - return 0.25 * k2.swoosh_l_forward(x * 4) + return _swoosh_l_forward_wrapper(x) else: return 0.25 * k2.swoosh_l(x * 4) @@ -1492,7 +1504,7 @@ def forward(self, x: Tensor) -> Tensor: zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return 0.25 * logaddexp(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 if not x.requires_grad: - return 0.25 * k2.swoosh_r_forward(4 * x) + return _swoosh_r_forward_wrapper(x) else: return 0.25 * k2.swoosh_r(4 * x) @@ -1552,8 +1564,8 @@ def forward( ctx.activation = activation forward_activation_dict = { - "SwooshL": k2.swoosh_l_forward, - "SwooshR": k2.swoosh_r_forward, + "SwooshL": _swoosh_l_forward_wrapper, + "SwooshR": _swoosh_r_forward_wrapper, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1571,8 +1583,8 @@ def backward(ctx, ans_grad: Tensor): (x, weight, bias, dropout_mask) = saved forward_and_deriv_activation_dict = { - "SwooshL": k2.swoosh_l_forward_and_deriv, - "SwooshR": k2.swoosh_r_forward_and_deriv, + "SwooshL": _swoosh_l_forward_and_deriv_wrapper, + "SwooshR": _swoosh_r_forward_and_deriv_wrapper, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. From a1821c4edb16f5a708a9aced438a47f00e5762d2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Mar 2025 00:08:44 +0800 Subject: [PATCH 0215/1191] Implement OrthogonalLinear more efficiently. --- egs/librispeech/ASR/zipformer/optim.py | 3 +- egs/librispeech/ASR/zipformer/scaling.py | 118 ++++++++++++++++------- 2 files changed, 85 insertions(+), 36 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index a419d6091a..a0e15236c8 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1243,7 +1243,7 @@ def step(self, closure=None): def _test_scaled_adam(hidden_dim: int): import timeit - from scaling import ScaledLinear + from scaling import ScaledLinear #, OrthogonalLinear E = 100 B = 4 @@ -1266,6 +1266,7 @@ def _test_scaled_adam(hidden_dim: int): m = torch.nn.Sequential( Linear(E, hidden_dim), + #OrthogonalLinear(hidden_dim, hidden_dim), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), torch.nn.PReLU(), diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2ed22303b6..2db5aa1f61 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -548,48 +548,100 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans + +class OrthogonalLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x, weight): + ctx.save_for_backward(x, weight) + return torch.matmul(x, weight.t()) + + @staticmethod + @custom_bwd + def backward(ctx, y_grad): + x, weight = ctx.saved_tensors + + if x.requires_grad: + x_grad = torch.matmul(y_grad, weight) + else: + x_grad = None + + if weight.requires_grad: + weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), + x.reshape(-1, x.shape[-1])) + + # now get extra gradient term that penalizes non-orthogonality. + weight = weight.detach() + + if weight.shape[0] > weight.shape[1]: + prod = torch.matmul(weight.t(), weight) + else: + prod = torch.matmul(weight, weight.t()) + + # we'll try to enforce that prod is any constant times the identity. + + # in the loss-function: + # orthognonality_loss = ((prod * alpha - I) ** 2).sum(), + # the following formula gives the alpha that means d(err)/d(scale-of-prod) will be zero. + # alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) + # we actually need 1/alpha: + inverse_alpha = (prod ** 2).sum(dim=1).mean(dim=0) / prod.diag().mean() + if random.random() < 0.01: + loss = ((prod / inverse_alpha - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype)) ** 2).mean() * prod.shape[0] + logging.info(f"inverse_alpha = {inverse_alpha.item()}, loss={loss.item()}") + # OK, imagining: + # err = 0.5 * ((prod ** 2).sum() * (alpha ** 2) + # - 2 * alpha * prod.diag().sum() + # = prod.shape[0]) + # do d(err)/d(prod) = (alpha**2) * prod - alpha * I + #... and we'll be normalizing out any scalar factor anyway, so by dividing + # by alpha**2 we can treat it as: + # d(err)/d(prod) = prod - I / alpha + prod_deriv = prod + N = prod.shape[0] + prod_deriv_diag = torch.as_strided(prod_deriv, size=(N,), stride=(N+1,)) + prod_deriv_diag.add_(-inverse_alpha) # modified prod_deriv in place + # now, assuming we had not done transpose above, d(err)/d(weight) = 2 * torch.matmul( + if weight.shape[0] > weight.shape[1]: + weight_err_grad = torch.matmul(weight, prod_deriv) + else: + weight_err_grad = torch.matmul(prod_deriv, weight) + # now scale weight_err_grad to have the same norm as weight grad: this will make sure + # it has about the required magnitude without overwhelming the main loss. + eps = torch.finfo(weight_err_grad.dtype).tiny + + # err_rel_scale is set less than one mostly so diagnostics about gradient scale will + # reflect something close to the actual gradient scale. This will be enough for it + # to fully enforce the constraint. + err_rel_scale = 0.25 + err_scale = err_rel_scale * weight_grad.abs().mean() / (weight_err_grad.abs().mean() + eps) + + weight_grad += err_scale * weight_err_grad + else: + weight_grad = None + return x_grad, weight_grad + + + class OrthogonalLinear(nn.Linear): + # penalty_scale does nothing, it is deprecated. def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): super().__init__(num_channels, num_channels, bias=False) - self.penalty_scale = copy.deepcopy(penalty_scale) - self.min_product_scale = 0.01 - self.name = None # will be set from training loop. for printing penalty. with torch.no_grad(): # this is not orthogonal but should quickly become so. self.weight[:] = torch.randn(num_channels, num_channels) * (num_channels ** -0.5) def forward(self, x: Tensor): - ans = nn.functional.linear(x, self.weight, self.bias) - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return ans - penalty_scale = float(self.penalty_scale) - if penalty_scale == 0.0: - return ans - weight = self.weight - if weight.shape[0] > weight.shape[1]: - weight = weight.t() - prod = torch.matmul(weight, weight.t()) # enforce that this is any constant times the identity. - # could include penalty_scale later on, but we do it at this point to make overflow of - # grads less likely (because they are aggregated earlier on, via sum()). - prod = scale_grad(prod, penalty_scale) - with torch.no_grad(): - alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) - alpha = alpha.clamp_(max=1. / self.min_product_scale) - - # following is equivalent to penalty_scale ((prod * alpha - I) ** - # 2).sum(), but more memory and compute efficient. - err = ((prod ** 2).sum() * (alpha ** 2) + - (-2 * alpha) * prod.diag().sum() + - prod.shape[0]) - - ans = with_loss(ans, err, self.name) - if random.random() < 0.001 or __name__ == '__main__': - with torch.no_grad(): - ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"{self.name}: product_scale={1/alpha}, dim={weight.shape}, avg_err = {err} * {penalty_scale} = {err*penalty_scale}, ans-rms={ans_rms}") + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return torch.nn.functional.linear(x, self.weight, self.bias) + + ans = OrthogonalLinearFunction.apply(x, self.weight) + if self.bias is not None: + ans = ans + self.bias return ans + def OrthogonalLinearSpecial(num_channels: int, penalty_scale: float = 1000.0, transpose: bool = False): @@ -1885,10 +1937,6 @@ def _test_activation_dropout_and_linear(): x1 = torch.randn(10, in_channels) x1.requires_grad = True - # TEMP. - assert torch.allclose( - SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 - ) x2 = x1.clone().detach() x2.requires_grad = True From 6beacab7487d7af1cfc3005780602d3bfba42317 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Mar 2025 10:10:09 +0800 Subject: [PATCH 0216/1191] Implement a learning-rate factor of 0.5 for all non-residual components --- egs/librispeech/ASR/zipformer/model.py | 16 +++++++++++++++ egs/librispeech/ASR/zipformer/subsampling.py | 7 +++++++ egs/librispeech/ASR/zipformer/zipformer.py | 21 ++++++++++++-------- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 5cd86b0972..a4420deb91 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -95,6 +95,8 @@ def __init__( assert hasattr(decoder, "blank_id") assert joiner is not None + + self.decoder = decoder self.joiner = joiner @@ -104,6 +106,7 @@ def __init__( self.simple_lm_proj = ScaledLinear( decoder_dim, vocab_size, initial_scale=0.1, ) + else: assert decoder is None assert joiner is None @@ -127,6 +130,19 @@ def __init__( encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) self.reconstruction_loss = torch.nn.SmoothL1Loss(reduction='none', beta=1.0) + + + # lr_scale is a learning-rate factor for non-residual components; + # it will be interpreted by get_parameter_groups_with_lrs() + for m in ['decoder', 'joiner', 'simple_am_proj', 'simple_lm_proj', + 'reconstruction_proj', 'ctc_output']: + try: + module = getattr(self, m) + module.lr_scale = 0.5 + except AttributeError: # e.g. use_ctc == False + pass + + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index b007455f1b..c5bb059497 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -208,6 +208,7 @@ def __init__( SwooshR(), ) + # just one convnext layer self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) @@ -215,10 +216,16 @@ def __init__( self.out_width = (((in_channels - 1) // 2) - 1) // 2 self.layer3_channels = layer3_channels + # scale it up a bit, else the output is quite small. self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, initial_scale=4.0) + # conv.lr_scale and out.lr_scale are learning-rate factors for non-residual components; + # they will be interpreted by get_parameter_groups_with_lrs(). + self.conv.lr_scale = 0.5 + self.out.lr_scale = 0.5 + self.out_limiter = ScaleLimiter(max_scale=4.0) # use a larger than normal grad_scale on this whitening module; there is diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 213de795db..75458651f2 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -143,7 +143,7 @@ def _to_tuple(x): self.chunk_size = chunk_size self.left_context_frames = left_context_frames - # each one will be Zipformer2Encoder or InvertibleDownsample or InvertibleUpsample + # each one will be Zipformer2Encoder or OrthogonalDownsample or OrthogonalUpsample encoders = [] num_encoders = len(downsampling_factor) @@ -159,11 +159,11 @@ def _to_tuple(x): def set_downsample_factor(cur_downsample, ds): while cur_downsample < ds: # need to downsample - encoders.append(InvertibleDownsample(channels=input_dim * cur_downsample, + encoders.append(OrthogonalDownsample(channels=input_dim * cur_downsample, proj_dim=min(2 * input_dim * cur_downsample, max_proj_dim))) cur_downsample *= 2 while cur_downsample > ds: - encoders.append(InvertibleUpsample(channels=input_dim * cur_downsample, + encoders.append(OrthogonalUpsample(channels=input_dim * cur_downsample, proj_dim=min(input_dim * cur_downsample, max_proj_dim))) cur_downsample //= 2 return cur_downsample @@ -921,9 +921,9 @@ def forward(self, src_orig: Tensor, src: Tensor): -class InvertibleDownsample(torch.nn.Module): +class OrthogonalDownsample(torch.nn.Module): """ - Does downsampling in an invertible way, by a factor of two. Projection is initialized + Does downsampling with an orthogonal matrix, by a factor of two. Projection is initialized in a special way and enforced to be orthogonal. Args: @@ -943,6 +943,9 @@ def __init__( super().__init__() assert proj_dim <= channels * 2 self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) + # this is a learning-rate factor for non-residual components; lr_scale will be interpreted by + # get_parameter_groups_with_lrs(). + self.proj.lr_scale = 0.5 self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -973,10 +976,9 @@ def forward(self, src: Tensor) -> Tensor: src = self.proj(src) return src -class InvertibleUpsample(torch.nn.Module): +class OrthogonalUpsample(torch.nn.Module): """ - A very simple form of upsampling that is the inverse of InvertibleDownsampling. - Projection is initialized in a special way and enforced to be orthogonal. + A very simple form of upsampling with an orthogonal matrix. proj_dim: the number of channels that will actually be projected; the rest are just copied. proj_dim=channels would mean all channels are projected in a learned way @@ -990,6 +992,9 @@ def __init__(self, channels: int, proj_dim: int, super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) + # lr_scale is a learning-rate factor for non-residual components; it will be interpreted by + # get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.5 def forward(self, src: Tensor) -> Tensor: """ From 4b0c6ebffd5f7ab3613523e3542e88737026a690 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Mar 2025 10:20:10 +0800 Subject: [PATCH 0217/1191] print name in OrthogonalLinear --- egs/librispeech/ASR/zipformer/scaling.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2db5aa1f61..d7808420a7 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -552,8 +552,9 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: class OrthogonalLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight): + def forward(ctx, x, weight, name): ctx.save_for_backward(x, weight) + ctx.name = name return torch.matmul(x, weight.t()) @staticmethod @@ -586,9 +587,9 @@ def backward(ctx, y_grad): # alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) # we actually need 1/alpha: inverse_alpha = (prod ** 2).sum(dim=1).mean(dim=0) / prod.diag().mean() - if random.random() < 0.01: + if random.random() < 0.002: loss = ((prod / inverse_alpha - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype)) ** 2).mean() * prod.shape[0] - logging.info(f"inverse_alpha = {inverse_alpha.item()}, loss={loss.item()}") + logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().item()}, loss={loss.item()}") # OK, imagining: # err = 0.5 * ((prod ** 2).sum() * (alpha ** 2) # - 2 * alpha * prod.diag().sum() @@ -619,7 +620,7 @@ def backward(ctx, y_grad): weight_grad += err_scale * weight_err_grad else: weight_grad = None - return x_grad, weight_grad + return x_grad, weight_grad, None @@ -627,6 +628,7 @@ class OrthogonalLinear(nn.Linear): # penalty_scale does nothing, it is deprecated. def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): super().__init__(num_channels, num_channels, bias=False) + self.name = None with torch.no_grad(): # this is not orthogonal but should quickly become so. @@ -636,7 +638,7 @@ def forward(self, x: Tensor): if torch.jit.is_scripting() or torch.jit.is_tracing(): return torch.nn.functional.linear(x, self.weight, self.bias) - ans = OrthogonalLinearFunction.apply(x, self.weight) + ans = OrthogonalLinearFunction.apply(x, self.weight, self.name) if self.bias is not None: ans = ans + self.bias return ans From ea4bf4dd65175f8928b4011d9814636edfb98559 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Mar 2025 10:40:08 +0800 Subject: [PATCH 0218/1191] Restore warmup to defaults. --- egs/librispeech/ASR/zipformer/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 7308b6884a..343711c2b2 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1377,8 +1377,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, - warmup_start=0.1, warmup_batches=1000) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From db45fa9e631f761296f51822b1ce8260eda82bfd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Mar 2025 14:06:45 +0800 Subject: [PATCH 0219/1191] Add OrthogonalLinear with out_groups in SelfAttention. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- egs/librispeech/ASR/zipformer/scaling.py | 85 +++++++++++++++------- egs/librispeech/ASR/zipformer/zipformer.py | 19 ++--- 3 files changed, 67 insertions(+), 39 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index a0e15236c8..fefc922639 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1266,7 +1266,7 @@ def _test_scaled_adam(hidden_dim: int): m = torch.nn.Sequential( Linear(E, hidden_dim), - #OrthogonalLinear(hidden_dim, hidden_dim), + #OrthogonalLinear(hidden_dim, hidden_dim, bias=True, out_groups=1), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), torch.nn.PReLU(), diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d7808420a7..a12c66628e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -552,9 +552,12 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: class OrthogonalLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, name): + def forward(ctx, x, weight, name, in_groups, out_groups): ctx.save_for_backward(x, weight) ctx.name = name + ctx.out_groups = out_groups + ctx.in_groups = in_groups + assert not (in_groups > 1 and out_groups > 1) return torch.matmul(x, weight.t()) @staticmethod @@ -572,24 +575,37 @@ def backward(ctx, y_grad): x.reshape(-1, x.shape[-1])) # now get extra gradient term that penalizes non-orthogonality. - weight = weight.detach() + # reshape weight to (groups, a, b) with a <= b (the latter is for efficiency) + if ctx.out_groups > 1: + w = weight.reshape(ctx.out_groups, -1, weight.shape[1]) + elif ctx.in_groups > 1: + w = weight.reshape(weight.shape[0], ctx.in_groups, -1).transpose(0, 1) + else: + w = weight.unsqueeze(0) - if weight.shape[0] > weight.shape[1]: - prod = torch.matmul(weight.t(), weight) + if (w.shape[1] > w.shape[2]): + prod = torch.matmul(w.transpose(1, 2), w) else: - prod = torch.matmul(weight, weight.t()) + prod = torch.matmul(w, w.transpose(1, 2)) - # we'll try to enforce that prod is any constant times the identity. + # we'll try to enforce that for any i, prod[i] is any constant times the identity. # in the loss-function: # orthognonality_loss = ((prod * alpha - I) ** 2).sum(), # the following formula gives the alpha that means d(err)/d(scale-of-prod) will be zero. # alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) # we actually need 1/alpha: - inverse_alpha = (prod ** 2).sum(dim=1).mean(dim=0) / prod.diag().mean() + + # note, prod_diag shares memory with prod, this will matter later on. + (groups, r, c) = prod.shape + (groups_stride, r_stride, c_stride) = prod.stride() + prod_diag = torch.as_strided(prod, size=(groups, r), stride=(groups_stride, r_stride+c_stride)) + # inverse_alpha: (groups, 1) + inverse_alpha = (prod ** 2).sum(dim=2).mean(dim=1, keepdim=True) / prod_diag.mean(dim=1, keepdim=True) if random.random() < 0.002: - loss = ((prod / inverse_alpha - torch.eye(prod.shape[0], device=prod.device, dtype=prod.dtype)) ** 2).mean() * prod.shape[0] - logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().item()}, loss={loss.item()}") + eye = torch.eye(prod.shape[1], device=prod.device, dtype=prod.dtype).unsqueeze(0) + loss = ((prod / inverse_alpha.unsqueeze(-1) - eye) ** 2).mean(dim=(1,2)) * prod.shape[1] + logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}") # OK, imagining: # err = 0.5 * ((prod ** 2).sum() * (alpha ** 2) # - 2 * alpha * prod.diag().sum() @@ -599,46 +615,65 @@ def backward(ctx, y_grad): # by alpha**2 we can treat it as: # d(err)/d(prod) = prod - I / alpha prod_deriv = prod - N = prod.shape[0] - prod_deriv_diag = torch.as_strided(prod_deriv, size=(N,), stride=(N+1,)) - prod_deriv_diag.add_(-inverse_alpha) # modified prod_deriv in place - # now, assuming we had not done transpose above, d(err)/d(weight) = 2 * torch.matmul( - if weight.shape[0] > weight.shape[1]: - weight_err_grad = torch.matmul(weight, prod_deriv) + prod_deriv_diag = prod_diag # since prod_deriv shares memory with prod. + prod_deriv_diag.add_(-inverse_alpha) # modifies prod_deriv in place + + # differentiate through computation of prod: + if w.shape[1] > w.shape[2]: + w_grad = torch.matmul(w, prod_deriv) else: - weight_err_grad = torch.matmul(prod_deriv, weight) - # now scale weight_err_grad to have the same norm as weight grad: this will make sure + w_grad = torch.matmul(prod_deriv, w) + + + # now reshape back to weight_grad2 (call it weight_grad2 to distinguish + # from the gradient weight_grad that comes from the main loss) + if ctx.out_groups > 1: + weight_grad2 = w_grad.reshape(*weight.shape) + elif ctx.in_groups > 1: + weight_grad2 = w_grad.transpose(0, 1).reshape(*weight.shape) + else: + weight_grad2 = w_grad.squeeze(0) + + + + # now scale weight_grad2 to have the same norm as weight grad: this will make sure # it has about the required magnitude without overwhelming the main loss. - eps = torch.finfo(weight_err_grad.dtype).tiny + eps = torch.finfo(weight_grad2.dtype).tiny # err_rel_scale is set less than one mostly so diagnostics about gradient scale will # reflect something close to the actual gradient scale. This will be enough for it # to fully enforce the constraint. err_rel_scale = 0.25 - err_scale = err_rel_scale * weight_grad.abs().mean() / (weight_err_grad.abs().mean() + eps) + err_scale = err_rel_scale * weight_grad.abs().mean() / (weight_grad2.abs().mean() + eps) - weight_grad += err_scale * weight_err_grad + weight_grad += err_scale * weight_grad2 else: weight_grad = None - return x_grad, weight_grad, None + return x_grad, weight_grad, None, None, None class OrthogonalLinear(nn.Linear): # penalty_scale does nothing, it is deprecated. - def __init__(self, num_channels: int, penalty_scale: FloatLike = 1000.0): - super().__init__(num_channels, num_channels, bias=False) + # if in_groups or out_groups are set to >1, the orthogonal constraint + # will be set per group. both of them cannot be >1. + def __init__(self, in_channels: int, out_channels: int, + in_groups: int = 1, out_groups: int = 1, + bias: bool = True): + super().__init__(in_channels, out_channels, bias=bias) self.name = None + self.in_groups = in_groups + self.out_groups = out_groups with torch.no_grad(): # this is not orthogonal but should quickly become so. - self.weight[:] = torch.randn(num_channels, num_channels) * (num_channels ** -0.5) + self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) def forward(self, x: Tensor): if torch.jit.is_scripting() or torch.jit.is_tracing(): return torch.nn.functional.linear(x, self.weight, self.bias) - ans = OrthogonalLinearFunction.apply(x, self.weight, self.name) + ans = OrthogonalLinearFunction.apply(x, self.weight, self.name, self.in_groups, self.out_groups) if self.bias is not None: ans = ans + self.bias return ans diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 75458651f2..7aba53caa9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -27,7 +27,6 @@ from encoder_interface import EncoderInterface from scaling import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - OrthogonalLinearSpecial, OrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ScaleLimiter, @@ -933,16 +932,13 @@ class OrthogonalDownsample(torch.nn.Module): proj_dim=2 * channels would mean all channels are projected in a learned way causal: True for causal systems, only affects error messages as requires even input num frames. - penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because - it may interact with the scale of the loss function, i.e. if the loss-function - scale is smaller you may want this to be smaller. """ def __init__( - self, channels: int, proj_dim: int, causal: bool = False, penalty_scale: float = 1000.0, + self, channels: int, proj_dim: int, causal: bool = False, ): super().__init__() assert proj_dim <= channels * 2 - self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) # this is a learning-rate factor for non-residual components; lr_scale will be interpreted by # get_parameter_groups_with_lrs(). self.proj.lr_scale = 0.5 @@ -982,16 +978,12 @@ class OrthogonalUpsample(torch.nn.Module): proj_dim: the number of channels that will actually be projected; the rest are just copied. proj_dim=channels would mean all channels are projected in a learned way - penalty_scale: Penalty scale to enforce orthogonal projection; this is specifiable because - it may interact with the scale of the loss function, i.e. if the loss-function - scale is smaller you may want this to be smaller. """ - def __init__(self, channels: int, proj_dim: int, - penalty_scale: float = 1000.0): + def __init__(self, channels: int, proj_dim: int): super().__init__() assert proj_dim <= channels - self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) # lr_scale is a learning-rate factor for non-residual components; it will be interpreted by # get_parameter_groups_with_lrs() self.proj.lr_scale = 0.5 @@ -1504,7 +1496,8 @@ def __init__( value_head_dim: int, ) -> None: super().__init__() - self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, + bias=True, out_groups=num_heads) self.out_proj = ScaledLinear( num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 From c058219075a9145fb748bcde12b3a4ea49ecdb00 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 14:34:16 +0800 Subject: [PATCH 0220/1191] Change non-residual lr_scales from .5 to .75 --- egs/librispeech/ASR/zipformer/model.py | 2 +- egs/librispeech/ASR/zipformer/subsampling.py | 4 ++-- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index a4420deb91..37e0aa3a1f 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -138,7 +138,7 @@ def __init__( 'reconstruction_proj', 'ctc_output']: try: module = getattr(self, m) - module.lr_scale = 0.5 + module.lr_scale = 0.75 except AttributeError: # e.g. use_ctc == False pass diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index c5bb059497..7bd857abfa 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -223,8 +223,8 @@ def __init__( # conv.lr_scale and out.lr_scale are learning-rate factors for non-residual components; # they will be interpreted by get_parameter_groups_with_lrs(). - self.conv.lr_scale = 0.5 - self.out.lr_scale = 0.5 + self.conv.lr_scale = 0.75 + self.out.lr_scale = 0.75 self.out_limiter = ScaleLimiter(max_scale=4.0) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 75458651f2..4435b2b134 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -945,7 +945,7 @@ def __init__( self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) # this is a learning-rate factor for non-residual components; lr_scale will be interpreted by # get_parameter_groups_with_lrs(). - self.proj.lr_scale = 0.5 + self.proj.lr_scale = 0.75 self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -994,7 +994,7 @@ def __init__(self, channels: int, proj_dim: int, self.proj = OrthogonalLinear(proj_dim, penalty_scale=penalty_scale) # lr_scale is a learning-rate factor for non-residual components; it will be interpreted by # get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.5 + self.proj.lr_scale = 0.75 def forward(self, src: Tensor) -> Tensor: """ From 7ca795920caaa303535639a34721fdf07c918eb6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 16:15:19 +0800 Subject: [PATCH 0221/1191] Learn projections at lr_scale=0.8 not 0.75 and remove other instances of lr_scale. --- egs/librispeech/ASR/zipformer/model.py | 12 ------------ egs/librispeech/ASR/zipformer/subsampling.py | 5 ----- egs/librispeech/ASR/zipformer/zipformer.py | 12 ++++++------ 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 37e0aa3a1f..991bf78dff 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -131,18 +131,6 @@ def __init__( self.reconstruction_loss = torch.nn.SmoothL1Loss(reduction='none', beta=1.0) - - # lr_scale is a learning-rate factor for non-residual components; - # it will be interpreted by get_parameter_groups_with_lrs() - for m in ['decoder', 'joiner', 'simple_am_proj', 'simple_lm_proj', - 'reconstruction_proj', 'ctc_output']: - try: - module = getattr(self, m) - module.lr_scale = 0.75 - except AttributeError: # e.g. use_ctc == False - pass - - def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 7bd857abfa..1b90826e1f 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -221,11 +221,6 @@ def __init__( self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, initial_scale=4.0) - # conv.lr_scale and out.lr_scale are learning-rate factors for non-residual components; - # they will be interpreted by get_parameter_groups_with_lrs(). - self.conv.lr_scale = 0.75 - self.out.lr_scale = 0.75 - self.out_limiter = ScaleLimiter(max_scale=4.0) # use a larger than normal grad_scale on this whitening module; there is diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 717493c9cf..4baf34a7c9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -939,9 +939,9 @@ def __init__( super().__init__() assert proj_dim <= channels * 2 self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) - # this is a learning-rate factor for non-residual components; lr_scale will be interpreted by - # get_parameter_groups_with_lrs(). - self.proj.lr_scale = 0.75 + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.8 self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -984,9 +984,9 @@ def __init__(self, channels: int, proj_dim: int): super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) - # lr_scale is a learning-rate factor for non-residual components; it will be interpreted by - # get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.75 + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.8 def forward(self, src: Tensor) -> Tensor: """ From 51839929361820af0c40a96e179d759ef6821cef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 17:08:03 +0800 Subject: [PATCH 0222/1191] Implement orthogonality constraint on key-projections. --- egs/librispeech/ASR/zipformer/optim.py | 5 +- egs/librispeech/ASR/zipformer/scaling.py | 107 ++++++++++++++++----- egs/librispeech/ASR/zipformer/zipformer.py | 15 ++- 3 files changed, 95 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index fefc922639..895e3a70fd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1243,7 +1243,7 @@ def step(self, closure=None): def _test_scaled_adam(hidden_dim: int): import timeit - from scaling import ScaledLinear #, OrthogonalLinear + from scaling import ScaledLinear, OrthogonalLinear E = 100 B = 4 @@ -1266,7 +1266,8 @@ def _test_scaled_adam(hidden_dim: int): m = torch.nn.Sequential( Linear(E, hidden_dim), - #OrthogonalLinear(hidden_dim, hidden_dim, bias=True, out_groups=1), + OrthogonalLinear(hidden_dim, hidden_dim, bias=True, + in_groups=2, group_size=hidden_dim//4), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), torch.nn.PReLU(), diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a12c66628e..48dbaa5270 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -497,7 +497,6 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - ans.bias[:] = 0.0 torch.nn.init.uniform_(ans.bias, -0.01 * initial_scale, 0.01 * initial_scale) return ans @@ -552,12 +551,13 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: class OrthogonalLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, name, in_groups, out_groups): + def forward(ctx, x, weight, name, in_groups, out_groups, group_size): ctx.save_for_backward(x, weight) ctx.name = name ctx.out_groups = out_groups ctx.in_groups = in_groups - assert not (in_groups > 1 and out_groups > 1) + ctx.group_size = group_size + assert not (in_groups > 0 and out_groups > 0) return torch.matmul(x, weight.t()) @staticmethod @@ -570,19 +570,27 @@ def backward(ctx, y_grad): else: x_grad = None + + out_groups, in_groups, group_size = ctx.out_groups, ctx.in_groups, ctx.group_size + if weight.requires_grad: + + weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), x.reshape(-1, x.shape[-1])) - # now get extra gradient term that penalizes non-orthogonality. - # reshape weight to (groups, a, b) with a <= b (the latter is for efficiency) - if ctx.out_groups > 1: - w = weight.reshape(ctx.out_groups, -1, weight.shape[1]) - elif ctx.in_groups > 1: - w = weight.reshape(weight.shape[0], ctx.in_groups, -1).transpose(0, 1) + # Now get extra gradient term that penalizes non-orthogonality. + + # First get w which is of shape (num_groups, out_channels_per_group, in_channels_per_group) + if out_groups > 0: + w = weight[:out_groups*group_size].reshape(out_groups, group_size, weight.shape[1]) + elif in_groups > 0: + w = weight[:, :in_groups*group_size].reshape(weight.shape[0], in_groups, group_size).transpose(0, 1) else: w = weight.unsqueeze(0) + # Compute symmetric matrix-product prod with the smallest dimension + # possible given the shape of w. if (w.shape[1] > w.shape[2]): prod = torch.matmul(w.transpose(1, 2), w) else: @@ -618,24 +626,31 @@ def backward(ctx, y_grad): prod_deriv_diag = prod_diag # since prod_deriv shares memory with prod. prod_deriv_diag.add_(-inverse_alpha) # modifies prod_deriv in place - # differentiate through computation of prod: + # manually differentiate backward through computation of prod: if w.shape[1] > w.shape[2]: w_grad = torch.matmul(w, prod_deriv) else: w_grad = torch.matmul(prod_deriv, w) - # now reshape back to weight_grad2 (call it weight_grad2 to distinguish - # from the gradient weight_grad that comes from the main loss) - if ctx.out_groups > 1: - weight_grad2 = w_grad.reshape(*weight.shape) - elif ctx.in_groups > 1: - weight_grad2 = w_grad.transpose(0, 1).reshape(*weight.shape) + # now manually differentiate backward through the expression that computed w in the + # if-elif-else statement above. + if out_groups > 0: + d = out_groups * group_size + weight_grad2 = w_grad.reshape(d, -1) + if d < weight.shape[0]: + z = torch.zeros(weight.shape[0] - d, weight.shape[1], dtype=weight_grad2.dtype, device=weight_grad2.device) + weight_grad2 = torch.cat((weight_grad2, z), dim=0) + elif in_groups > 0: + d = in_groups * group_size + weight_grad2 = w_grad.transpose(0, 1).reshape(weight.shape[0], d) + if d < weight.shape[1]: + z = torch.zeros(weight.shape[0], weight.shape[1] - d, dtype=weight_grad2.dtype, device=weight_grad2.device) + weight_grad2 = torch.cat((weight_grad2, z), dim=1) else: weight_grad2 = w_grad.squeeze(0) - # now scale weight_grad2 to have the same norm as weight grad: this will make sure # it has about the required magnitude without overwhelming the main loss. eps = torch.finfo(weight_grad2.dtype).tiny @@ -649,31 +664,73 @@ def backward(ctx, y_grad): weight_grad += err_scale * weight_grad2 else: weight_grad = None - return x_grad, weight_grad, None, None, None + return x_grad, weight_grad, None, None, None, None class OrthogonalLinear(nn.Linear): - # penalty_scale does nothing, it is deprecated. + """ + Like nn.Linear but can enforce that the weight matrix, or selected parts of it, is + orthogonal up to a scalar factor. We are using a generalized definition of "orthogonal" + that applies to non-square matrix, i.e. that either M^T M or M M^T, whichever has + fewer rows/columns, should be equal to the identity times some positive scalar alpha. + (If M is square, these definitions are equivalent and is equivalent to the normal + definition of orthogonal). + + Args: + in_channels: number of input channels + out_channels: number of output channels + in_groups: the number of groups on the input dimension, if specified + the orthogonality-up-to-a-scalar-factor constraint will be + applied separately per group, with different scalars. + out_groups: the number of groups on the output dimension; you cannot + specify both this and in_groups with values >0. + group_size: the number of channels per group. This provides a way + to ensure that only part of the matrix is subject to the + orthogonality constraint, e.g. if you specified in_groups>0, + you can specify group_size + such that in_groups * group_size < in_channels, and the + remaining channels will be unconstrained. + bias: if True, include a bias term. + initial_scale: a factor that allows you to increase or decrease the + initial scale of the weight (and bias, if present) + + """ # if in_groups or out_groups are set to >1, the orthogonal constraint # will be set per group. both of them cannot be >1. - def __init__(self, in_channels: int, out_channels: int, - in_groups: int = 1, out_groups: int = 1, - bias: bool = True): + def __init__(self, + in_channels: int, + out_channels: int, + in_groups: int = -1, + out_groups: int = -1, + group_size: int = -1, + bias: bool = True, + initial_scale: float = 1.0, + ): super().__init__(in_channels, out_channels, bias=bias) self.name = None self.in_groups = in_groups self.out_groups = out_groups + if in_groups > 0 and group_size == -1: + group_size = in_channels // in_groups + elif out_groups > 0 and group_size == -1: + group_size = out_channels // out_groups + self.group_size = group_size + # the same scaling as for ScaledLinear. with torch.no_grad(): - # this is not orthogonal but should quickly become so. - self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) + self.weight[:] *= initial_scale + if self.bias is not None: + torch.nn.init.uniform_(self.bias, -0.01 * initial_scale, 0.01 * initial_scale) + def forward(self, x: Tensor): if torch.jit.is_scripting() or torch.jit.is_tracing(): return torch.nn.functional.linear(x, self.weight, self.bias) - ans = OrthogonalLinearFunction.apply(x, self.weight, self.name, self.in_groups, self.out_groups) + ans = OrthogonalLinearFunction.apply(x, self.weight, self.name, + self.in_groups, self.out_groups, + self.group_size) if self.bias is not None: ans = ans + self.bias return ans diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4baf34a7c9..263cace154 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1175,8 +1175,13 @@ def __init__( # dividing it between the query and key. Note: this module is intended # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + # The OrthogonalLinear will make sure that the rows of the projection to each + # key will be orthogonal, while leaving the queries and position-queries + # unconstrained. + self.in_proj = OrthogonalLinear( + embed_dim, in_proj_dim, + out_groups=num_heads, group_size=key_head_dim, + bias=True, initial_scale=query_head_dim**-0.25 ) self.whiten_keys = Whiten( @@ -1225,8 +1230,8 @@ def forward( query_dim = query_head_dim * num_heads # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] + k = x[..., 0:query_dim] + q = x[..., query_dim : 2 * query_dim] # p is the position-encoding query p = x[..., 2 * query_dim :] assert p.shape[-1] == num_heads * pos_head_dim, ( @@ -1236,7 +1241,7 @@ def forward( ) q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(k) # does nothing in the forward pass. + k = self.whiten_keys(k) # does nothing in the forward pass. [this may not really be needed due to the orthogonality constraint.] p = self.copy_pos_query(p) # for diagnostics only, does nothing. q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) From 7db6be3e35fbffb4eca794c9ff61e1cb13a8dbbb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 21:54:28 +0800 Subject: [PATCH 0223/1191] Try to get rid of divergence due to inf in err_scale --- egs/librispeech/ASR/zipformer/scaling.py | 48 +++++++++++++--------- egs/librispeech/ASR/zipformer/zipformer.py | 3 +- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 48dbaa5270..90837aef7c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -574,8 +574,6 @@ def backward(ctx, y_grad): out_groups, in_groups, group_size = ctx.out_groups, ctx.in_groups, ctx.group_size if weight.requires_grad: - - weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), x.reshape(-1, x.shape[-1])) @@ -583,11 +581,13 @@ def backward(ctx, y_grad): # First get w which is of shape (num_groups, out_channels_per_group, in_channels_per_group) if out_groups > 0: - w = weight[:out_groups*group_size].reshape(out_groups, group_size, weight.shape[1]) + func = lambda x: x[:out_groups*group_size].reshape(out_groups, group_size, x.shape[1]) elif in_groups > 0: - w = weight[:, :in_groups*group_size].reshape(weight.shape[0], in_groups, group_size).transpose(0, 1) + func = lambda x: x[:, :in_groups*group_size].reshape(x.shape[0], in_groups, group_size).transpose(0, 1) else: - w = weight.unsqueeze(0) + func = lambda x: x.unsqueeze(0) + w = func(weight) + w_orig_grad = func(weight_grad) # Compute symmetric matrix-product prod with the smallest dimension # possible given the shape of w. @@ -610,10 +610,12 @@ def backward(ctx, y_grad): prod_diag = torch.as_strided(prod, size=(groups, r), stride=(groups_stride, r_stride+c_stride)) # inverse_alpha: (groups, 1) inverse_alpha = (prod ** 2).sum(dim=2).mean(dim=1, keepdim=True) / prod_diag.mean(dim=1, keepdim=True) - if random.random() < 0.002: + + do_print = random.random() < 0.005 + if do_print: eye = torch.eye(prod.shape[1], device=prod.device, dtype=prod.dtype).unsqueeze(0) loss = ((prod / inverse_alpha.unsqueeze(-1) - eye) ** 2).mean(dim=(1,2)) * prod.shape[1] - logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}") + # OK, imagining: # err = 0.5 * ((prod ** 2).sum() * (alpha ** 2) # - 2 * alpha * prod.diag().sum() @@ -633,6 +635,25 @@ def backward(ctx, y_grad): w_grad = torch.matmul(prod_deriv, w) + # we want to scale w_grad to have the same average element absolute-value + # as err_rel_scale times the average absolute-value in the + # corresponding part of weight_grad: this will make sure it has + # about the required magnitude without overwhelming the main loss. + eps = torch.finfo(w_grad.dtype).tiny + + # err_rel_scale is set less than one mostly so diagnostics about gradient scale will + # reflect something close to the actual gradient scale. This will be enough for it + # to fully enforce the constraint. + err_rel_scale = 0.05 + err_scale = err_rel_scale * (w_orig_grad.abs().mean(dim=(1,2), keepdim=True) / + w_grad.abs().mean(dim=(1,2), keepdim=True)) + err_scale = torch.nan_to_num(err_scale, nan=0.0, posinf=0.0, neginf=0.0) + + if do_print: + logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}, err_scale={err_scale.flatten()}") + + w_grad = w_grad * err_scale + # now manually differentiate backward through the expression that computed w in the # if-elif-else statement above. if out_groups > 0: @@ -650,18 +671,7 @@ def backward(ctx, y_grad): else: weight_grad2 = w_grad.squeeze(0) - - # now scale weight_grad2 to have the same norm as weight grad: this will make sure - # it has about the required magnitude without overwhelming the main loss. - eps = torch.finfo(weight_grad2.dtype).tiny - - # err_rel_scale is set less than one mostly so diagnostics about gradient scale will - # reflect something close to the actual gradient scale. This will be enough for it - # to fully enforce the constraint. - err_rel_scale = 0.25 - err_scale = err_rel_scale * weight_grad.abs().mean() / (weight_grad2.abs().mean() + eps) - - weight_grad += err_scale * weight_grad2 + weight_grad.add_(weight_grad2) else: weight_grad = None return x_grad, weight_grad, None, None, None, None diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 263cace154..99fce7bcf1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1229,7 +1229,8 @@ def forward( query_dim = query_head_dim * num_heads - # self-attention + ## self-attention + ## TODO: the keys have to be the leading dimension as we have orthogonality constraint on the keys. k = x[..., 0:query_dim] q = x[..., query_dim : 2 * query_dim] # p is the position-encoding query From 6fdcae80326f7213533f0baed4e7a7bfb91bc69d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 22:24:23 +0800 Subject: [PATCH 0224/1191] fix comment --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 99fce7bcf1..ad082d0840 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1229,8 +1229,8 @@ def forward( query_dim = query_head_dim * num_heads - ## self-attention - ## TODO: the keys have to be the leading dimension as we have orthogonality constraint on the keys. + # self-attention + # the keys have to come first as we have the orthogonality constraint on the keys. k = x[..., 0:query_dim] q = x[..., query_dim : 2 * query_dim] # p is the position-encoding query From 7a86d4d00b8ba886f5c3214a2bd3e951ebb461e9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 22:32:08 +0800 Subject: [PATCH 0225/1191] Set grad_scale=1.0e-04 in whiten_keys, to turn it off but still get log messages. revert err_rel_scale to 0.25 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 90837aef7c..856bbcbf0a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -644,7 +644,7 @@ def backward(ctx, y_grad): # err_rel_scale is set less than one mostly so diagnostics about gradient scale will # reflect something close to the actual gradient scale. This will be enough for it # to fully enforce the constraint. - err_rel_scale = 0.05 + err_rel_scale = 0.25 err_scale = err_rel_scale * (w_orig_grad.abs().mean(dim=(1,2), keepdim=True) / w_grad.abs().mean(dim=(1,2), keepdim=True)) err_scale = torch.nan_to_num(err_scale, nan=0.0, posinf=0.0, neginf=0.0) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ad082d0840..e2d16d9bff 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1188,7 +1188,7 @@ def __init__( num_groups=num_heads, whitening_limit=_whitening_schedule(3.0), prob=(0.025, 0.25), - grad_scale=0.025, + grad_scale=1.0e-05, ) # linear transformation for positional encoding. From e80068ad999d803412dda7d6c6bee4641a161b69 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 22:36:46 +0800 Subject: [PATCH 0226/1191] Revert err_rel_scale from .05 to 0.25 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 90837aef7c..856bbcbf0a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -644,7 +644,7 @@ def backward(ctx, y_grad): # err_rel_scale is set less than one mostly so diagnostics about gradient scale will # reflect something close to the actual gradient scale. This will be enough for it # to fully enforce the constraint. - err_rel_scale = 0.05 + err_rel_scale = 0.25 err_scale = err_rel_scale * (w_orig_grad.abs().mean(dim=(1,2), keepdim=True) / w_grad.abs().mean(dim=(1,2), keepdim=True)) err_scale = torch.nan_to_num(err_scale, nan=0.0, posinf=0.0, neginf=0.0) From 8d7cb1077e2ec082b52e2d8b0150f698d90bca2b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Mar 2025 23:18:38 +0800 Subject: [PATCH 0227/1191] Avoid large scales getting turned to zero in OrthogonalLinearFunction --- egs/librispeech/ASR/zipformer/scaling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 856bbcbf0a..1cd9aeb9d6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -645,9 +645,10 @@ def backward(ctx, y_grad): # reflect something close to the actual gradient scale. This will be enough for it # to fully enforce the constraint. err_rel_scale = 0.25 - err_scale = err_rel_scale * (w_orig_grad.abs().mean(dim=(1,2), keepdim=True) / - w_grad.abs().mean(dim=(1,2), keepdim=True)) - err_scale = torch.nan_to_num(err_scale, nan=0.0, posinf=0.0, neginf=0.0) + err_scale = ((err_rel_scale * w_orig_grad.abs().mean(dim=(1,2), keepdim=True)) / + w_grad.abs().mean(dim=(1,2), keepdim=True) + eps) + # the 5000.0 is just a 'large scale' in case the division overflows in float16. + err_scale = torch.nan_to_num(err_scale, nan=0.0, posinf=5000.0, neginf=0.0) if do_print: logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}, err_scale={err_scale.flatten()}") From 92e079342a7324b54b16fa951ace35f755314072 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Mar 2025 10:04:25 +0800 Subject: [PATCH 0228/1191] revert zipformer.py to the way it was before orthogonalizing the keys. --- egs/librispeech/ASR/zipformer/zipformer.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e2d16d9bff..ff84e68d93 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1175,12 +1175,8 @@ def __init__( # dividing it between the query and key. Note: this module is intended # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. - # The OrthogonalLinear will make sure that the rows of the projection to each - # key will be orthogonal, while leaving the queries and position-queries - # unconstrained. - self.in_proj = OrthogonalLinear( + self.in_proj = ScaledLinear( embed_dim, in_proj_dim, - out_groups=num_heads, group_size=key_head_dim, bias=True, initial_scale=query_head_dim**-0.25 ) @@ -1188,7 +1184,7 @@ def __init__( num_groups=num_heads, whitening_limit=_whitening_schedule(3.0), prob=(0.025, 0.25), - grad_scale=1.0e-05, + grad_scale=0.025, ) # linear transformation for positional encoding. @@ -1230,9 +1226,8 @@ def forward( query_dim = query_head_dim * num_heads # self-attention - # the keys have to come first as we have the orthogonality constraint on the keys. - k = x[..., 0:query_dim] - q = x[..., query_dim : 2 * query_dim] + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query p = x[..., 2 * query_dim :] assert p.shape[-1] == num_heads * pos_head_dim, ( From e6539b8ce694bcec6b496f1e77fd1b04ee17106f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 15:29:51 +0800 Subject: [PATCH 0229/1191] Fix missing parenthesis regarding err_scale --- egs/librispeech/ASR/zipformer/scaling.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1cd9aeb9d6..45fb88e912 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -646,9 +646,7 @@ def backward(ctx, y_grad): # to fully enforce the constraint. err_rel_scale = 0.25 err_scale = ((err_rel_scale * w_orig_grad.abs().mean(dim=(1,2), keepdim=True)) / - w_grad.abs().mean(dim=(1,2), keepdim=True) + eps) - # the 5000.0 is just a 'large scale' in case the division overflows in float16. - err_scale = torch.nan_to_num(err_scale, nan=0.0, posinf=5000.0, neginf=0.0) + (w_grad.abs().mean(dim=(1,2), keepdim=True) + eps)) if do_print: logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}, err_scale={err_scale.flatten()}") From 7ea864e1d44f440c4fb8bec62c4be9dcf744aefb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 15:44:34 +0800 Subject: [PATCH 0230/1191] Revert initialization of OrthogonalLinear to the way it previously was --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 45fb88e912..d4555218bc 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -728,7 +728,7 @@ def __init__(self, # the same scaling as for ScaledLinear. with torch.no_grad(): - self.weight[:] *= initial_scale + self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * initial_scale if self.bias is not None: torch.nn.init.uniform_(self.bias, -0.01 * initial_scale, 0.01 * initial_scale) From 60e9bca97eba1bd6e7c546fd1b9c64a80b58a2c2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 16:00:43 +0800 Subject: [PATCH 0231/1191] Reduce initial scale of ExpNorm from e-1 to 1.5 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d4555218bc..a1ecf26a2e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -453,7 +453,7 @@ def __init__( super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.scale = nn.Parameter(torch.tensor(1.718281828)) + self.scale = nn.Parameter(torch.tensor(1.5)) self.name = None From 7b44aaaf3cd042d7c026252303fdd33312de0ef0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 16:35:44 +0800 Subject: [PATCH 0232/1191] Change initial scale of expnorm to 2.0 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a1ecf26a2e..548950772b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -453,7 +453,7 @@ def __init__( super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.scale = nn.Parameter(torch.tensor(1.5)) + self.scale = nn.Parameter(torch.tensor(2.0)) self.name = None From 12eb04ab6d71df707c0563fbf1d22ef19f958fca Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 17:31:32 +0800 Subject: [PATCH 0233/1191] Something natural-gradienty, like changing power of 1.0 in Adam to 1.2. --- egs/librispeech/ASR/zipformer/optim.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 895e3a70fd..7a4fd7a517 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -146,7 +146,13 @@ def basic_step(group, p, state, grad): exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) denom = exp_avg_sq.sqrt().add_(eps) - return -lr * grad / denom + ## following three are tunable. + power = 0.2 + factor_max = 1.2 + factor_min = 0.8 + factor = ((denom / denom.mean()) ** power).clamp_(min=factor_min, max=factor_max) + + return -lr * grad / (denom * factor) def scaling_step(group, p, state, grad): From 69247dab7e3076c5c4b10a2b6f940f817a79cca0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 17:32:54 +0800 Subject: [PATCH 0234/1191] Reverse direction of modified scale in adam --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7a4fd7a517..25d011f210 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -152,7 +152,7 @@ def basic_step(group, p, state, grad): factor_min = 0.8 factor = ((denom / denom.mean()) ** power).clamp_(min=factor_min, max=factor_max) - return -lr * grad / (denom * factor) + return -lr * grad * factor / denom def scaling_step(group, p, state, grad): From e8b3d8a2988ddbff92c1901a5fb746a9d5b7d7ef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 19:36:09 +0800 Subject: [PATCH 0235/1191] Set scale of loss to 1000 times average main-loss derivative magnitude --- egs/librispeech/ASR/zipformer/scaling.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 548950772b..ed2fa0b34c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -641,12 +641,10 @@ def backward(ctx, y_grad): # about the required magnitude without overwhelming the main loss. eps = torch.finfo(w_grad.dtype).tiny - # err_rel_scale is set less than one mostly so diagnostics about gradient scale will - # reflect something close to the actual gradient scale. This will be enough for it - # to fully enforce the constraint. - err_rel_scale = 0.25 - err_scale = ((err_rel_scale * w_orig_grad.abs().mean(dim=(1,2), keepdim=True)) / - (w_grad.abs().mean(dim=(1,2), keepdim=True) + eps)) + # err_rel_scale is the scale of the orthogonality loss function relative to the + # average scale of the "main" gradient term. + err_rel_scale = 1000.0 + err_scale = err_rel_scale * w_orig_grad.abs().mean(dim=(1,2), keepdim=True) if do_print: logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}, err_scale={err_scale.flatten()}") From 00fe63589bfe69725996561921fd20271c3ffbbf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 20:23:29 +0800 Subject: [PATCH 0236/1191] remove proj.lr_scale = 0.8 --- egs/librispeech/ASR/zipformer/zipformer.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ff84e68d93..8a1bd05a9b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -939,9 +939,6 @@ def __init__( super().__init__() assert proj_dim <= channels * 2 self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.8 self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -984,9 +981,7 @@ def __init__(self, channels: int, proj_dim: int): super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.8 + def forward(self, src: Tensor) -> Tensor: """ From bcdfa73ae9e3f5fcbcb758245696bda330af30c9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 7 Mar 2025 10:58:02 +0800 Subject: [PATCH 0237/1191] Simplify code for OrthogonalLinearFunction --- egs/librispeech/ASR/zipformer/scaling.py | 151 +++++++++-------------- 1 file changed, 61 insertions(+), 90 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ed2fa0b34c..236c5f0117 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -577,98 +577,69 @@ def backward(ctx, y_grad): weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), x.reshape(-1, x.shape[-1])) - # Now get extra gradient term that penalizes non-orthogonality. + penalty_scale = 1000.0 * weight_grad.abs().mean() - # First get w which is of shape (num_groups, out_channels_per_group, in_channels_per_group) - if out_groups > 0: - func = lambda x: x[:out_groups*group_size].reshape(out_groups, group_size, x.shape[1]) - elif in_groups > 0: - func = lambda x: x[:, :in_groups*group_size].reshape(x.shape[0], in_groups, group_size).transpose(0, 1) - else: - func = lambda x: x.unsqueeze(0) - w = func(weight) - w_orig_grad = func(weight_grad) - - # Compute symmetric matrix-product prod with the smallest dimension - # possible given the shape of w. - if (w.shape[1] > w.shape[2]): - prod = torch.matmul(w.transpose(1, 2), w) - else: - prod = torch.matmul(w, w.transpose(1, 2)) - - # we'll try to enforce that for any i, prod[i] is any constant times the identity. - - # in the loss-function: - # orthognonality_loss = ((prod * alpha - I) ** 2).sum(), - # the following formula gives the alpha that means d(err)/d(scale-of-prod) will be zero. - # alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) - # we actually need 1/alpha: - - # note, prod_diag shares memory with prod, this will matter later on. - (groups, r, c) = prod.shape - (groups_stride, r_stride, c_stride) = prod.stride() - prod_diag = torch.as_strided(prod, size=(groups, r), stride=(groups_stride, r_stride+c_stride)) - # inverse_alpha: (groups, 1) - inverse_alpha = (prod ** 2).sum(dim=2).mean(dim=1, keepdim=True) / prod_diag.mean(dim=1, keepdim=True) - - do_print = random.random() < 0.005 - if do_print: - eye = torch.eye(prod.shape[1], device=prod.device, dtype=prod.dtype).unsqueeze(0) - loss = ((prod / inverse_alpha.unsqueeze(-1) - eye) ** 2).mean(dim=(1,2)) * prod.shape[1] - - # OK, imagining: - # err = 0.5 * ((prod ** 2).sum() * (alpha ** 2) - # - 2 * alpha * prod.diag().sum() - # = prod.shape[0]) - # do d(err)/d(prod) = (alpha**2) * prod - alpha * I - #... and we'll be normalizing out any scalar factor anyway, so by dividing - # by alpha**2 we can treat it as: - # d(err)/d(prod) = prod - I / alpha - prod_deriv = prod - prod_deriv_diag = prod_diag # since prod_deriv shares memory with prod. - prod_deriv_diag.add_(-inverse_alpha) # modifies prod_deriv in place - - # manually differentiate backward through computation of prod: - if w.shape[1] > w.shape[2]: - w_grad = torch.matmul(w, prod_deriv) - else: - w_grad = torch.matmul(prod_deriv, w) - - - # we want to scale w_grad to have the same average element absolute-value - # as err_rel_scale times the average absolute-value in the - # corresponding part of weight_grad: this will make sure it has - # about the required magnitude without overwhelming the main loss. - eps = torch.finfo(w_grad.dtype).tiny - - # err_rel_scale is the scale of the orthogonality loss function relative to the - # average scale of the "main" gradient term. - err_rel_scale = 1000.0 - err_scale = err_rel_scale * w_orig_grad.abs().mean(dim=(1,2), keepdim=True) - - if do_print: - logging.info(f"OrthogonalLinear: name={ctx.name}, scale={inverse_alpha.sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}, err_scale={err_scale.flatten()}") - - w_grad = w_grad * err_scale - - # now manually differentiate backward through the expression that computed w in the - # if-elif-else statement above. - if out_groups > 0: - d = out_groups * group_size - weight_grad2 = w_grad.reshape(d, -1) - if d < weight.shape[0]: - z = torch.zeros(weight.shape[0] - d, weight.shape[1], dtype=weight_grad2.dtype, device=weight_grad2.device) - weight_grad2 = torch.cat((weight_grad2, z), dim=0) - elif in_groups > 0: - d = in_groups * group_size - weight_grad2 = w_grad.transpose(0, 1).reshape(weight.shape[0], d) - if d < weight.shape[1]: - z = torch.zeros(weight.shape[0], weight.shape[1] - d, dtype=weight_grad2.dtype, device=weight_grad2.device) - weight_grad2 = torch.cat((weight_grad2, z), dim=1) - else: - weight_grad2 = w_grad.squeeze(0) + with torch.enable_grad(): + weight = weight.detach() + weight.requires_grad = True + + # Get extra gradient term that penalizes non-orthogonality. + + # First get w which is of shape (num_groups, out_channels_per_group, in_channels_per_group) + if out_groups > 0: + w = weight[:out_groups*group_size].reshape(out_groups, group_size, weight.shape[1]) + elif in_groups > 0: + w = weight[:, :in_groups*group_size].reshape(weight.shape[0], in_groups, group_size).transpose(0, 1) + else: + w = weight.unsqueeze(0) + + + # Compute symmetric matrix-product prod with the smallest + # dimension possible given the shape of w. This is not just for + # efficiency; if we computed it the wrong way round, the product + # would have deficient rank and could never be the identity. + if (w.shape[1] > w.shape[2]): + prod = torch.matmul(w.transpose(1, 2), w) + else: + prod = torch.matmul(w, w.transpose(1, 2)) + + # we'll try to enforce that for any i, prod[i] is any constant times the identity. + + # in the loss-function: + # orthogonality_loss = ((prod * alpha - I) ** 2).sum(), + # the following formula gives the alpha that means d(err)/d(scale-of-prod) will be zero. + # alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) + + # note, prod_diag shares memory with prod, this will matter later on. + (groups, r, c) = prod.shape + (groups_stride, r_stride, c_stride) = prod.stride() + + def diag_inplace(z): + return torch.as_strided(z, size=(groups, r), stride=(groups_stride, r_stride+c_stride)) + + with torch.no_grad(): + # alpha: (groups, 1) + alpha = (diag_inplace(prod).mean(dim=1, keepdim=True) / + (prod ** 2).sum(dim=2).mean(dim=1, keepdim=True)) + + prod *= alpha.unsqueeze(-1) + diag_inplace(prod)[:] -= 1. + + # that loss that we want to backprop would be 0.5 * (prod ** + # 2).sum() * penalty_scale. we can backprop this without doing + # any reductions as follows: + prod.backward(gradient=prod * penalty_scale) + + + do_print = random.random() < 0.005 + if do_print: + # we print a normalized version of the loss, by dividing by the + # number of rows. + loss = (prod ** 2).mean(dim=(1,2)) * prod.shape[1] + logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}, penalty_scale={penalty_scale}") - weight_grad.add_(weight_grad2) + # add the extra gradient term from the orthogonality loss. + weight_grad += weight.grad else: weight_grad = None return x_grad, weight_grad, None, None, None, None From ab8da9d85bff1c7c0ea85bb4f528ad969c79ee43 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 7 Mar 2025 11:16:00 +0800 Subject: [PATCH 0238/1191] detach in printout. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 236c5f0117..31f63853e5 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -636,7 +636,7 @@ def diag_inplace(z): # we print a normalized version of the loss, by dividing by the # number of rows. loss = (prod ** 2).mean(dim=(1,2)) * prod.shape[1] - logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.cpu().flatten()}, penalty_scale={penalty_scale}") + logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}") # add the extra gradient term from the orthogonality loss. weight_grad += weight.grad From cb16ae5378d8e92e1684e4b5ec33a032e5d26ef2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Mar 2025 16:49:02 +0800 Subject: [PATCH 0239/1191] Increase power in adam change from .2 to .3 so it's like power 0.8->0.7 in adam; increase factor min,max range from .8,1.2 to .7,1.3. --- egs/librispeech/ASR/zipformer/optim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 25d011f210..914a11f681 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -147,9 +147,9 @@ def basic_step(group, p, state, grad): denom = exp_avg_sq.sqrt().add_(eps) ## following three are tunable. - power = 0.2 - factor_max = 1.2 - factor_min = 0.8 + power = 0.3 + factor_max = 1.3 + factor_min = 0.7 factor = ((denom / denom.mean()) ** power).clamp_(min=factor_min, max=factor_max) return -lr * grad * factor / denom From 2b230b37bb56ef91da7a64feb307c19eec1707fa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Mar 2025 20:23:29 +0800 Subject: [PATCH 0240/1191] Restore proj.lr_scale, at 0.75; restore exp_norm initial scale to 1.7; rename BiasNorm to ExpNorm and document it. --- egs/librispeech/ASR/zipformer/my_profile.py | 4 +- egs/librispeech/ASR/zipformer/scaling.py | 39 +++++++++----------- egs/librispeech/ASR/zipformer/subsampling.py | 4 +- egs/librispeech/ASR/zipformer/zipformer.py | 10 ++++- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/my_profile.py b/egs/librispeech/ASR/zipformer/my_profile.py index 7e1fd777a3..f87613eb08 100755 --- a/egs/librispeech/ASR/zipformer/my_profile.py +++ b/egs/librispeech/ASR/zipformer/my_profile.py @@ -26,7 +26,7 @@ import sentencepiece as spm import torch -from scaling import BiasNorm +from scaling import ExpNorm from torch import Tensor, nn from train import ( add_model_arguments, @@ -81,7 +81,7 @@ def _bypass_module_flops_compute(module, input, output): MODULE_HOOK_MAPPING = { - BiasNorm: _bias_norm_flops_compute, + ExpNorm: _bias_norm_flops_compute, BypassModule: _bypass_module_flops_compute, } diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 31f63853e5..5ca765c10a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -366,7 +366,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - scales = (1. - (-x_norm).exp()) / x_norm # torch.log1p(x_norm) / x_norm + scales = (1. - (-x_norm).exp()) / x_norm scales = scale * scales return (x * scales) @@ -415,24 +415,25 @@ def c(x): -class BiasNorm(torch.nn.Module): +class ExpNorm(torch.nn.Module): """ - Comment not up-to-date. This is ExpNorm. Will change docs later if - promising. - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. + LayerNorm, without the learned weight or bias. There is just one learned + parameter, a scalar, which is a scale on the output; and it is limited + during training to the range [0.5..2.5]. + + Unlike LayerNorm it does not pick the scale that maps any rms value at the + input to an rms value of 1 at the output, i.e. the function f(x) = 1 (which + discards the length information); instead, it uses the function: + f(x) = scale * (1 - (-x).exp()), + i.e. if the input rms value was x, it gets mapped to the f(x) above. The + implementation is just: - Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization, in addition to a separate trainable - "eps" parameter, learned in log-space. We also give it a (scalar) - trainable scale on the output. + x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() + scales = (1. - (-x_norm).exp()) / x_norm + return (x * scale * scales) + where 'scale' is a scalar, and the only learned parameter. Args: num_channels: the number of channels, e.g. 512. @@ -440,20 +441,16 @@ class BiasNorm(torch.nn.Module): interpreted as an offset from the input's ndim if negative. This is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - log_scale: the initial log-scale that we multiply the output by; this - is learnable. - log_scale_min: FloatLike, minimum allowed value of log_scale - log_scale_max: FloatLike, maximum allowed value of log_scale """ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. ) -> None: - super(BiasNorm, self).__init__() + super(ExpNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.scale = nn.Parameter(torch.tensor(2.0)) + self.scale = nn.Parameter(torch.tensor(1.7)) self.name = None diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 1b90826e1f..7763e16a56 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -24,7 +24,7 @@ Balancer, ScaleLimiter, ScaledLinear, - BiasNorm, + ExpNorm, Dropout3, FloatLike, Optional, @@ -235,7 +235,7 @@ def __init__( # max_log_eps=0.0 is to prevent both eps and the output of self.out from # getting large, there is an unnecessary degree of freedom. - self.out_norm = BiasNorm(out_channels) + self.out_norm = ExpNorm(out_channels) self.dropout = Dropout3(dropout, shared_dim=1) def forward( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8a1bd05a9b..37d6c19fe8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -31,7 +31,7 @@ ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ScaleLimiter, ActivationDropoutAndLinear, - BiasNorm, + ExpNorm, ChunkCausalDepthwiseConv1d, Dropout2, FloatLike, @@ -522,7 +522,7 @@ def __init__( self.scale_limiter = ScaleLimiter(max_scale=2.0) - self.norm = BiasNorm(embed_dim) + self.norm = ExpNorm(embed_dim) def forward( @@ -939,6 +939,9 @@ def __init__( super().__init__() assert proj_dim <= channels * 2 self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 self.causal = causal def forward(self, src: Tensor) -> Tensor: @@ -981,6 +984,9 @@ def __init__(self, channels: int, proj_dim: int): super().__init__() assert proj_dim <= channels self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 def forward(self, src: Tensor) -> Tensor: From 6feb735105b01ce457619bd037af164beb5cddb3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 11 Mar 2025 13:01:20 +0800 Subject: [PATCH 0241/1191] Add copy_bypass to ZipformerEncoder, for diagnostics. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 37d6c19fe8..bf5409e251 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -730,6 +730,7 @@ def __init__( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers + self.copy_bypass = nn.Identity() # in case we are dumping diagnostics. self.whiten = Whiten( num_groups=1, @@ -783,6 +784,7 @@ def forward( src = self.whiten(src) if num_channels > layer_dim: + bypass = self.copy_bypass(bypass) src = torch.cat((src, bypass), dim=-1) return src From a883740832343b99571398918c8e6de790689cfe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Mar 2025 18:45:45 +0800 Subject: [PATCH 0242/1191] Bug fix RE which dims are averaged in adam-with-power. --- egs/librispeech/ASR/zipformer/optim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 914a11f681..16039148a0 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -150,7 +150,8 @@ def basic_step(group, p, state, grad): power = 0.3 factor_max = 1.3 factor_min = 0.7 - factor = ((denom / denom.mean()) ** power).clamp_(min=factor_min, max=factor_max) + dims = tuple(range(1, denom.ndim)) + factor = ((denom / denom.mean(dim=dims, keepdim=True)) ** power).clamp_(min=factor_min, max=factor_max) return -lr * grad * factor / denom From 6d4c0e8161da69cf6b35e0939735482df72bf5b3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Mar 2025 14:04:11 +0800 Subject: [PATCH 0243/1191] Change whiten prob in zipformer encoder to 1. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bf5409e251..2e33c30b0e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -735,7 +735,7 @@ def __init__( self.whiten = Whiten( num_groups=1, whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), + prob=(1, 1), grad_scale=0.025, ) From 4bc29b9f1f873e84ff74785a3b62f2c5bceb50f7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Mar 2025 15:36:12 +0800 Subject: [PATCH 0244/1191] Revert optim.py to 315conv, take out power stuff but keep whitening prob change --- egs/librispeech/ASR/zipformer/optim.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 16039148a0..895e3a70fd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -146,14 +146,7 @@ def basic_step(group, p, state, grad): exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) denom = exp_avg_sq.sqrt().add_(eps) - ## following three are tunable. - power = 0.3 - factor_max = 1.3 - factor_min = 0.7 - dims = tuple(range(1, denom.ndim)) - factor = ((denom / denom.mean(dim=dims, keepdim=True)) ** power).clamp_(min=factor_min, max=factor_max) - - return -lr * grad * factor / denom + return -lr * grad / denom def scaling_step(group, p, state, grad): From 48e6e9c443a5fe3c71f85c38b349f71a0defd76b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Mar 2025 15:53:15 +0800 Subject: [PATCH 0245/1191] Implement multiple-betas decay. --- egs/librispeech/ASR/zipformer/optim.py | 37 +++++++++++++++++++++----- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 895e3a70fd..ae785c240f 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -261,18 +261,41 @@ def scaling_step(group, p, state, grad): def momentum_step(group, p, state, grad): delta = scaling_step(group, p, state, grad) - beta1 = group["betas"][0] + #beta1 = group["betas"][0] + + # hardcode betas. + # see simulate_params.py on my laptop for how I got these settings. + try: - stored_delta = state["delta"] + stored_delta1 = state["delta1"] + stored_delta2 = state["delta2"] except KeyError: - stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - state["delta"] = stored_delta - stored_delta.mul_(beta1) - stored_delta.add_(delta, alpha=(1-beta1)) + stored_delta1 = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + stored_delta2 = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta1"] = stored_delta1 + state["delta2"] = stored_delta2 + + #scales=(0.9, -0.075, 0.175): alpha=0.1, lr=0.04, beta=(0.9999, 0.999, 0), data_var=0.05122422448145114 + + # caution, these are not the same as the beta1,beta2 in adam, they are betas for decay of + # different time periods. + step = state["step"] + beta2 = min(0.999, 1. / (step + 10)) + beta1 = 1. - 0.1 * (1. - beta2) + + + scale1 = 0.9 + scale2 = -0.075 + scale_direct = 1. - scale1 - scale2 + + stored_delta1.mul_(beta1) + stored_delta1.add_(delta, alpha=(1-beta1)) + stored_delta2.mul_(beta2) + stored_delta2.add_(delta, alpha=(1-beta2)) # we don't bother doing the "bias correction" part of Adam for beta1 because this is just # an edge effect that affects the first 10 or so batches; and the effect of not doing it # is just to do a slower update for the first few batches, which will help stability. - return stored_delta + return scale_direct * delta + scale1 * stored_delta1 + scale2 * stored_delta2 From 8e663cc8a94807fb6caac7a269b5d936f3d146bc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Mar 2025 17:46:10 +0800 Subject: [PATCH 0246/1191] Fix formulas so we are actualy doing what I intended. --- egs/librispeech/ASR/zipformer/optim.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ae785c240f..dcaa45d22e 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -280,8 +280,9 @@ def momentum_step(group, p, state, grad): # caution, these are not the same as the beta1,beta2 in adam, they are betas for decay of # different time periods. step = state["step"] - beta2 = min(0.999, 1. / (step + 10)) - beta1 = 1. - 0.1 * (1. - beta2) + beta1 = 0.9999 # min(0.9999, 1. - 1. / (2 * step + 100)) + beta2 = 0.999 #1. - 10.0 * (1. - beta1) + assert beta2 > 0, (beta2, beta1) scale1 = 0.9 @@ -1310,7 +1311,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0) + optim = ScaledAdam(m.named_parameters(), lr=0.1, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From d0c8d441d6ef7fd1d94c67edda063118e4353bd7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 16:17:47 +0800 Subject: [PATCH 0247/1191] commit some temp changes --- egs/librispeech/ASR/zipformer/optim.py | 71 +++++++++++++++++--------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index dcaa45d22e..7df16fd32a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -261,43 +261,66 @@ def scaling_step(group, p, state, grad): def momentum_step(group, p, state, grad): delta = scaling_step(group, p, state, grad) + #beta1 = group["betas"][0] # hardcode betas. # see simulate_params.py on my laptop for how I got these settings. try: - stored_delta1 = state["delta1"] - stored_delta2 = state["delta2"] + stored_delta = state["delta"] + momentum_rate = state["momentum_rate"] except KeyError: - stored_delta1 = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - stored_delta2 = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - state["delta1"] = stored_delta1 - state["delta2"] = stored_delta2 + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + momentum_rate = 0.01 * torch.ones(p.shape[0], *([1] * (p.ndim-1)), device=p.device, dtype=torch.float) + state["delta"] = stored_delta + state["momentum_rate"] = momentum_rate + + + if p.numel() == p.shape[0]: + # scalar. use conventional momentum. + beta = 0.9 + stored_delta.mul_(beta).add(delta, alpha=(1-beta)) + return stored_delta + + + + - #scales=(0.9, -0.075, 0.175): alpha=0.1, lr=0.04, beta=(0.9999, 0.999, 0), data_var=0.05122422448145114 - # caution, these are not the same as the beta1,beta2 in adam, they are betas for decay of - # different time periods. + + lr = group["lr"] step = state["step"] - beta1 = 0.9999 # min(0.9999, 1. - 1. / (2 * step + 100)) - beta2 = 0.999 #1. - 10.0 * (1. - beta1) - assert beta2 > 0, (beta2, beta1) + # decay near beginning as early grads may change fast. + stored_delta.mul_(1. - 1 / (10 + step)) + stored_delta += delta + + + if step > 200: + # 200 is twice the inverse of the initial/default momentum_rate. + + # grad_scale tells us how large the grad is relative to a single frame's worth of grad (of expected + # magnitude) + grad_scale = torch.mean(stored_delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / (lr ** 2) + + + factor = 0.2 + target_grad_scale = factor / (momentum_rate ** 0.8) + grad_too_large = (grad_scale > target_grad_scale) + # if grad is too large we may have to decrease epsilon. but very slowly. + + adapt_momentum_eps = 0.2 / (100 + step ** 0.8) + momentum_rate *= torch.where(grad_too_large, + 1. - adapt_momentum_eps, + 1. + adapt_momentum_eps) + momentum_rate.clamp_(max=0.1) - scale1 = 0.9 - scale2 = -0.075 - scale_direct = 1. - scale1 - scale2 - stored_delta1.mul_(beta1) - stored_delta1.add_(delta, alpha=(1-beta1)) - stored_delta2.mul_(beta2) - stored_delta2.add_(delta, alpha=(1-beta2)) - # we don't bother doing the "bias correction" part of Adam for beta1 because this is just - # an edge effect that affects the first 10 or so batches; and the effect of not doing it - # is just to do a slower update for the first few batches, which will help stability. - return scale_direct * delta + scale1 * stored_delta1 + scale2 * stored_delta2 + if random.random() < 0.001: + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, grad_scale={grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate}") + return delta + momentum_rate * stored_delta def debug_step(group, p, state, grad): @@ -1311,7 +1334,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.1, clipping_scale=2.0) + optim = ScaledAdam(m.named_parameters(), lr=0.015, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From 1df056dafb090695ec001a51c2ff51f12735ffd9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 16:31:32 +0800 Subject: [PATCH 0248/1191] Commit verison that gets to loss 0.15 in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7df16fd32a..32503ba8ef 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -272,6 +272,7 @@ def momentum_step(group, p, state, grad): momentum_rate = state["momentum_rate"] except KeyError: stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + stored_delta_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) momentum_rate = 0.01 * torch.ones(p.shape[0], *([1] * (p.ndim-1)), device=p.device, dtype=torch.float) state["delta"] = stored_delta state["momentum_rate"] = momentum_rate @@ -293,7 +294,8 @@ def momentum_step(group, p, state, grad): step = state["step"] # decay near beginning as early grads may change fast. - stored_delta.mul_(1. - 1 / (10 + step)) + beta = 1. - 1 / (10 + step) + stored_delta.mul_(beta) stored_delta += delta @@ -302,7 +304,9 @@ def momentum_step(group, p, state, grad): # grad_scale tells us how large the grad is relative to a single frame's worth of grad (of expected # magnitude) - grad_scale = torch.mean(stored_delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / (lr ** 2) + eps = 1.0e-20 + grad_scale = (torch.mean(stored_delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / + (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) factor = 0.2 @@ -318,7 +322,7 @@ def momentum_step(group, p, state, grad): if random.random() < 0.001: - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, grad_scale={grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate}") + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, grad_scale={grad_scale.flatten().to('cpu')}, target_grad_scale={target_grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate}") return delta + momentum_rate * stored_delta @@ -1334,7 +1338,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.015, clipping_scale=2.0) + optim = ScaledAdam(m.named_parameters(), lr=0.015, clipping_scale=2.0, eps=1.0e-20) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From d57e43736c868bf92869f355d450b7ba695feda7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 16:38:09 +0800 Subject: [PATCH 0249/1191] Implement an optimizer with no-decay momentum, adaptive momentum rate. --- egs/librispeech/ASR/zipformer/optim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 32503ba8ef..78d8126f33 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -309,8 +309,8 @@ def momentum_step(group, p, state, grad): (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) - factor = 0.2 - target_grad_scale = factor / (momentum_rate ** 0.8) + factor = 0.25 + target_grad_scale = factor / momentum_rate grad_too_large = (grad_scale > target_grad_scale) # if grad is too large we may have to decrease epsilon. but very slowly. @@ -321,7 +321,7 @@ def momentum_step(group, p, state, grad): momentum_rate.clamp_(max=0.1) - if random.random() < 0.001: + if random.random() < 0.0002: logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, grad_scale={grad_scale.flatten().to('cpu')}, target_grad_scale={target_grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate}") return delta + momentum_rate * stored_delta From 17b60cc95df58acf40497331960ccd5298dccb66 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 16:41:25 +0800 Subject: [PATCH 0250/1191] Multiply scalar update by 5. --- egs/librispeech/ASR/zipformer/optim.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 78d8126f33..1bd2128913 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -282,7 +282,10 @@ def momentum_step(group, p, state, grad): # scalar. use conventional momentum. beta = 0.9 stored_delta.mul_(beta).add(delta, alpha=(1-beta)) - return stored_delta + # mul by 5 because this optimizer expects about 5 times smaller + # learning rates, the user-provided LR being just the non-momentum part of the LR. + # we will clean this up later. + return 5.0 * stored_delta From fa33000fed4bd47ed817645d59cdaac8939eb1d5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 16:53:57 +0800 Subject: [PATCH 0251/1191] Flatten when printingh --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 1bd2128913..e335f063dd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -325,7 +325,7 @@ def momentum_step(group, p, state, grad): if random.random() < 0.0002: - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, grad_scale={grad_scale.flatten().to('cpu')}, target_grad_scale={target_grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate}") + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, grad_scale={grad_scale.flatten().to('cpu')}, target_grad_scale={target_grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate.flatten()}") return delta + momentum_rate * stored_delta From 6ef578cd71d66b2d3a200276e04698a5c1b72967 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 17:09:34 +0800 Subject: [PATCH 0252/1191] Different way of setting eps, as 0.25 * corr. --- egs/librispeech/ASR/zipformer/optim.py | 46 +++++++++++++++++--------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index e335f063dd..384bd96968 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -269,12 +269,14 @@ def momentum_step(group, p, state, grad): try: stored_delta = state["delta"] + prev_delta = state["prev_delta"] momentum_rate = state["momentum_rate"] except KeyError: stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - stored_delta_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + prev_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) momentum_rate = 0.01 * torch.ones(p.shape[0], *([1] * (p.ndim-1)), device=p.device, dtype=torch.float) state["delta"] = stored_delta + state["prev_delta"] = prev_delta state["momentum_rate"] = momentum_rate @@ -291,8 +293,6 @@ def momentum_step(group, p, state, grad): - - lr = group["lr"] step = state["step"] @@ -301,31 +301,47 @@ def momentum_step(group, p, state, grad): stored_delta.mul_(beta) stored_delta += delta - if step > 200: + # 200 is twice the inverse of the initial/default momentum_rate. # grad_scale tells us how large the grad is relative to a single frame's worth of grad (of expected # magnitude) eps = 1.0e-20 - grad_scale = (torch.mean(stored_delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / - (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + # gives us an idea of lr * alpha, where alpha is the mean-of-diagonal of 2nd deriv of loss function + + delta_corr = -(torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / + (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + # target for eps will be factor * delta_corr. we can try tuning this, it will + # likely be important. factor = 0.25 - target_grad_scale = factor / momentum_rate - grad_too_large = (grad_scale > target_grad_scale) + adapt_momentum_eps = 1.0 / (100 + step ** 0.8) + rate_target = factor * delta_corr + momentum_rate.mul_(1. - adapt_momentum_eps) + momentum_rate.add_(rate_target, alpha=adapt_momentum_eps) + momentum_rate.clamp_(min=0.0001, max=0.1) + + if random.random() < 0.0002: + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, delta_corr={delta_corr.flatten().to('cpu')}, rate_target={rate_target.flatten().to('cpu')}, rate={momentum_rate.flatten().to('cpu')}, eps={eps}") + + #grad_scale={grad_scale.flatten().to('cpu')}, target_grad_scale={target_grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate.flatten()}") + + + + #factor = 0.25 + #target_grad_scale = factor / momentum_rate + #grad_too_large = (grad_scale > target_grad_scale) # if grad is too large we may have to decrease epsilon. but very slowly. - adapt_momentum_eps = 0.2 / (100 + step ** 0.8) - momentum_rate *= torch.where(grad_too_large, - 1. - adapt_momentum_eps, - 1. + adapt_momentum_eps) - momentum_rate.clamp_(max=0.1) + #momentum_rate *= torch.where(grad_too_large, + # 1. - adapt_momentum_eps, + # 1. + adapt_momentum_eps) + #momentum_rate.clamp_(max=0.1) - if random.random() < 0.0002: - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, grad_scale={grad_scale.flatten().to('cpu')}, target_grad_scale={target_grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate.flatten()}") + prev_delta.copy_(delta) return delta + momentum_rate * stored_delta From 9157ef50ca0e02ff975428349cdce181953da943 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 17:30:53 +0800 Subject: [PATCH 0253/1191] Reduce max rate from .1 to .01 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 384bd96968..c51c7a83de 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -321,7 +321,7 @@ def momentum_step(group, p, state, grad): rate_target = factor * delta_corr momentum_rate.mul_(1. - adapt_momentum_eps) momentum_rate.add_(rate_target, alpha=adapt_momentum_eps) - momentum_rate.clamp_(min=0.0001, max=0.1) + momentum_rate.clamp_(min=0.0001, max=0.01) if random.random() < 0.0002: logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, delta_corr={delta_corr.flatten().to('cpu')}, rate_target={rate_target.flatten().to('cpu')}, rate={momentum_rate.flatten().to('cpu')}, eps={eps}") From 4802aaf54b0bb9bf5f185b6886dc4fdcf4f49267 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 17:44:59 +0800 Subject: [PATCH 0254/1191] have max momentum rate at 1.0, may help stability --- egs/librispeech/ASR/zipformer/optim.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c51c7a83de..c9727a5231 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -321,7 +321,7 @@ def momentum_step(group, p, state, grad): rate_target = factor * delta_corr momentum_rate.mul_(1. - adapt_momentum_eps) momentum_rate.add_(rate_target, alpha=adapt_momentum_eps) - momentum_rate.clamp_(min=0.0001, max=0.01) + momentum_rate.clamp_(min=0.0001, max=1.0) if random.random() < 0.0002: logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, delta_corr={delta_corr.flatten().to('cpu')}, rate_target={rate_target.flatten().to('cpu')}, rate={momentum_rate.flatten().to('cpu')}, eps={eps}") @@ -340,10 +340,12 @@ def momentum_step(group, p, state, grad): # 1. + adapt_momentum_eps) #momentum_rate.clamp_(max=0.1) + # the 0.5 * (prev_delta + delta) is a very basic, dumb momentum that is to stop + # divergence. prev_delta.copy_(delta) - - return delta + momentum_rate * stored_delta + ans = delta + momentum_rate * stored_delta + return ans def debug_step(group, p, state, grad): From 1540320b5d13fc31525ac1b01a16685a5c162f6d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 19:03:45 +0800 Subject: [PATCH 0255/1191] Change momentum_rate to a fixed function of step. --- egs/librispeech/ASR/zipformer/optim.py | 41 +++++--------------------- 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c9727a5231..670df6ed2c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -270,14 +270,12 @@ def momentum_step(group, p, state, grad): try: stored_delta = state["delta"] prev_delta = state["prev_delta"] - momentum_rate = state["momentum_rate"] except KeyError: stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) prev_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - momentum_rate = 0.01 * torch.ones(p.shape[0], *([1] * (p.ndim-1)), device=p.device, dtype=torch.float) state["delta"] = stored_delta state["prev_delta"] = prev_delta - state["momentum_rate"] = momentum_rate + if p.numel() == p.shape[0]: @@ -301,44 +299,19 @@ def momentum_step(group, p, state, grad): stored_delta.mul_(beta) stored_delta += delta - if step > 200: - - # 200 is twice the inverse of the initial/default momentum_rate. - - # grad_scale tells us how large the grad is relative to a single frame's worth of grad (of expected - # magnitude) - eps = 1.0e-20 - # gives us an idea of lr * alpha, where alpha is the mean-of-diagonal of 2nd deriv of loss function - delta_corr = -(torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / - (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + momentum_rate = 3.0 / (100 + step ** 0.8) + if random.random() < 0.0002: - # target for eps will be factor * delta_corr. we can try tuning this, it will - # likely be important. - factor = 0.25 - adapt_momentum_eps = 1.0 / (100 + step ** 0.8) - rate_target = factor * delta_corr - momentum_rate.mul_(1. - adapt_momentum_eps) - momentum_rate.add_(rate_target, alpha=adapt_momentum_eps) - momentum_rate.clamp_(min=0.0001, max=1.0) - - if random.random() < 0.0002: - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, delta_corr={delta_corr.flatten().to('cpu')}, rate_target={rate_target.flatten().to('cpu')}, rate={momentum_rate.flatten().to('cpu')}, eps={eps}") - - #grad_scale={grad_scale.flatten().to('cpu')}, target_grad_scale={target_grad_scale.flatten().to('cpu')}, inv_momentum_rate={1/momentum_rate.flatten()}") + eps = 1.0e-20 + delta_corr = (torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / + (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={delta_corr.flatten().to('cpu')}") - #factor = 0.25 - #target_grad_scale = factor / momentum_rate - #grad_too_large = (grad_scale > target_grad_scale) - # if grad is too large we may have to decrease epsilon. but very slowly. - #momentum_rate *= torch.where(grad_too_large, - # 1. - adapt_momentum_eps, - # 1. + adapt_momentum_eps) - #momentum_rate.clamp_(max=0.1) # the 0.5 * (prev_delta + delta) is a very basic, dumb momentum that is to stop # divergence. From 3191c093e7f1cf588bfead102558429d00d520df Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Mar 2025 19:52:19 +0800 Subject: [PATCH 0256/1191] Add a short-time momentum with beta=0.8 to avoid divergence; fix bug regarding scalars. --- egs/librispeech/ASR/zipformer/optim.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 670df6ed2c..65259ec746 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -268,20 +268,22 @@ def momentum_step(group, p, state, grad): # see simulate_params.py on my laptop for how I got these settings. try: + summed_grad = state["summed_grad"] stored_delta = state["delta"] prev_delta = state["prev_delta"] except KeyError: + summed_grad = torch.zeros(*p.shape, device=p.device, dtype=torch.float) stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) prev_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["summed_grad"] = summed_grad state["delta"] = stored_delta state["prev_delta"] = prev_delta - if p.numel() == p.shape[0]: # scalar. use conventional momentum. beta = 0.9 - stored_delta.mul_(beta).add(delta, alpha=(1-beta)) + stored_delta.mul_(beta).add_(delta, alpha=(1-beta)) # mul by 5 because this optimizer expects about 5 times smaller # learning rates, the user-provided LR being just the non-momentum part of the LR. # we will clean this up later. @@ -290,14 +292,13 @@ def momentum_step(group, p, state, grad): - lr = group["lr"] step = state["step"] # decay near beginning as early grads may change fast. beta = 1. - 1 / (10 + step) - stored_delta.mul_(beta) - stored_delta += delta + summed_grad.mul_(beta) + summed_grad += delta momentum_rate = 3.0 / (100 + step ** 0.8) @@ -317,8 +318,13 @@ def momentum_step(group, p, state, grad): # divergence. prev_delta.copy_(delta) - ans = delta + momentum_rate * stored_delta - return ans + + + momentum_beta = 0.8 # this is an additional short-time momentum, just to prevent divergence early on. + stored_delta.mul_(momentum_beta) + stored_delta.add_(delta + momentum_rate * summed_grad, alpha=(1-momentum_beta)) + + return stored_delta def debug_step(group, p, state, grad): From 88a90604488231795c7bba562863f8ce70514f40 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 11:45:45 +0800 Subject: [PATCH 0257/1191] Version that works OK in optim.py test --- egs/librispeech/ASR/zipformer/optim.py | 96 ++++++++++---------------- 1 file changed, 38 insertions(+), 58 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 65259ec746..55e7027a87 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -122,6 +122,7 @@ def batched_params(self, param_group, group_params_names): + def basic_step(group, p, state, grad): # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. lr = group["lr"] @@ -149,6 +150,7 @@ def basic_step(group, p, state, grad): return -lr * grad / denom + def scaling_step(group, p, state, grad): delta = basic_step(group, p, state, grad) if p.numel() == p.shape[0]: @@ -197,11 +199,16 @@ def scaling_step(group, p, state, grad): if step % size_update_period == size_update_period - 1 and step > 0: # This block updates the size of parameter by adding a step ("delta") value in - # the direction of either shrinking or growing it. + # the direction of either shrinking or growing it. it also includes a modified + # form of adamw-like shrinkage, which we modify a bit to ensure there is a unique + # optimum for the scales (since thanks to the "delta *= param_rms.clamp(min=min_rms)", + # there is no longer the size-stabilizing phenomenon as in Adam whereby parameters with smaller + # rms will tend to grow faster thanks to parameter noise). beta2 = group["betas"][1] size_lr = group["lr"] * group["scalar_lr_scale"] - max_rms = group["weight_max_rms"] if p.ndim > 2 else group["bias_max_rms"] + penalty_rms = group["weight_penalty_rms"] if p.ndim > 2 else group["bias_penalty_rms"] eps = group["eps"] + decay_scale = group["decay_scale"] batch_size = p.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. @@ -217,44 +224,15 @@ def scaling_step(group, p, state, grad): denom = scale_exp_avg_sq.sqrt() + eps - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + scale_norm_grad = ( + (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom ) + # add AdamW-like decay, if we are above "penalty_rms" which is not a hard + # maximum but a cutoff for introducing decay. + scale_norm_grad = scale_norm_grad + decay_scale * (param_rms / penalty_rms).log().clamp_(min=0.0) + scale_step = -size_lr * scale_norm_grad - # turn off the scale-step once param_rms is below min_rms, scale becomes - # 1.0 once we are twice param_min_rms. - scale_step_factor = ((param_rms / min_rms) - 1.).clamp_(min=0.0, max=1.0) - - # The following may help prevent instability: don't allow the scale step to be too large in - # either direction. - # TODO: remove this. - scale_step.clamp_(min=-0.1, max=0.1) - - # and ensure the parameter rms after update never exceeds max_rms. - # We have to look at the trained model for parameters at or around the - # max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, (max_rms - param_rms) / param_rms) - - - # (1 + lr**2) ** 0.5 ~ 1 + (0.5 lr**2) would be the factor by which the parameter rms - # increases on each step, assuming the gradient is orthogonal to the current - # parameter value. we cancel this out by subtracting (0.5 * lr**2); we - # need to do this times size_update_period. - - CORRECTION_FACTOR = 0.35 if is_weight else 0.5 - # mathematically this should be 0.5. 0.25 gives less-aggressive shrinkage. give the more-aggressive shrinkage - # of 0.5 for biases, as the biases getting relatively smaller will tend to prevent failure of the grad to propagate. - scale_step = scale_step - (CORRECTION_FACTOR * (group["lr"] ** 2) * size_update_period) - - scale_step = scale_step_factor * scale_step - - # the "+ 0.5 * scale_step ** 2" can be thought of as taking the second - # term in the Taylor expansion of exp(s) - 1, which is s + s^2 / 2!. - # this is so that in effect we are learning the scale in log space, - # so to represent it in p we have to exponentiate it. it's to avoid - # a downward bias in the scale that might otherwise happen. - delta.add_(p * (scale_step + 0.5 * scale_step ** 2)) + delta.add_(p * scale_step) return delta @@ -286,7 +264,7 @@ def momentum_step(group, p, state, grad): stored_delta.mul_(beta).add_(delta, alpha=(1-beta)) # mul by 5 because this optimizer expects about 5 times smaller # learning rates, the user-provided LR being just the non-momentum part of the LR. - # we will clean this up later. + # we will try to find a way to clean this up later. return 5.0 * stored_delta @@ -440,13 +418,12 @@ class ScaledAdam(BatchedOptimizer): weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of learning the scale on the parameters. Weight tensors are defined as anything with more than one element and ndim > 1. - weight_max_rms: Maximum root-mean-square value of weight tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each weight - parameter tensor to be <= this value). + weight_penalty_rms: Value of root-mean-square value of weight tensor, above which we + do adamw-style decay. bias_min_rms: Minimum root-mean-square value of bias tensors, defined as anything with more than one element and exactly one tensor dimension i.e. ndim == 1. - bias_max_rms: Maximum root-mean-square value of bias tensors, defined as anything with - more than one element and exactly one tensor dimension i.e. ndim == 1. + bias_penalty_rms: Value of root-mean-square value of bias tensor, above which we + do adamw-style decay. scalar_max: Maximum absolute value for scalar parameters (applicable if your model has any parameters with numel() == 1). size_update_period: The periodicity, in steps, with which we update the size (scale) @@ -465,9 +442,10 @@ def __init__( scalar_lr_scale=0.05, eps=1.0e-08, weight_min_rms=0.005, - weight_max_rms=1.0, + weight_penalty_rms=0.05, bias_min_rms=1.0e-05, - bias_max_rms=3.0, + bias_penalty_rms=0.2, + decay_scale=0.02, scalar_max=10.0, size_update_period=4, clipping_update_period=100, @@ -481,9 +459,10 @@ def __init__( scalar_lr_scale=scalar_lr_scale, eps=eps, weight_min_rms=weight_min_rms, - weight_max_rms=weight_max_rms, + weight_penalty_rms=weight_penalty_rms, bias_min_rms=bias_min_rms, - bias_max_rms=bias_max_rms, + bias_penalty_rms=bias_penalty_rms, + decay_scale=decay_scale, scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, @@ -1362,18 +1341,19 @@ def _test_scaled_adam(hidden_dim: int): else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + norm2 = '%.2e' % (m[1].weight**2).mean().sqrt().item() + norm3 = '%.2e' % (m[3].weight**2).mean().sqrt().item() + norm4 = '%.2e' % (m[5].weight**2).mean().sqrt().item() + + bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() + bias_norm2 = '%.2e' % (m[3].bias**2).mean().sqrt().item() + bias_norm3 = '%.2e' % (m[5].bias**2).mean().sqrt().item() + lr = scheduler.get_last_lr()[0] logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3,norm4}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + ) loss.log().backward() optim.step() optim.zero_grad() From 38fc70518ea3714fec39789c91ed2d991a733f2e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 12:05:33 +0800 Subject: [PATCH 0258/1191] Version that is buggy but works very well in optim.py, loss 0.09 --- egs/librispeech/ASR/zipformer/optim.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 55e7027a87..d52d5b736a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -262,12 +262,7 @@ def momentum_step(group, p, state, grad): # scalar. use conventional momentum. beta = 0.9 stored_delta.mul_(beta).add_(delta, alpha=(1-beta)) - # mul by 5 because this optimizer expects about 5 times smaller - # learning rates, the user-provided LR being just the non-momentum part of the LR. - # we will try to find a way to clean this up later. - return 5.0 * stored_delta - - + return stored_delta lr = group["lr"] @@ -439,7 +434,8 @@ def __init__( lr=3e-02, clipping_scale=None, betas=(0.9, 0.98), - scalar_lr_scale=0.05, + scalar_lr_scale=0.25, + scaling_lr_scale=0.05, eps=1.0e-08, weight_min_rms=0.005, weight_penalty_rms=0.05, @@ -457,6 +453,7 @@ def __init__( clipping_scale=clipping_scale, betas=betas, scalar_lr_scale=scalar_lr_scale, + scaling_lr_scale=scaling_lr_scale, eps=eps, weight_min_rms=weight_min_rms, weight_penalty_rms=weight_penalty_rms, From 4462e23e127425604498cb5e7493f3bc2d8703ce Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 12:08:16 +0800 Subject: [PATCH 0259/1191] Cleanup: split scaling and scalar lr factors but still same value. --- egs/librispeech/ASR/zipformer/optim.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index d52d5b736a..07cdb9bcfe 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -205,7 +205,7 @@ def scaling_step(group, p, state, grad): # there is no longer the size-stabilizing phenomenon as in Adam whereby parameters with smaller # rms will tend to grow faster thanks to parameter noise). beta2 = group["betas"][1] - size_lr = group["lr"] * group["scalar_lr_scale"] + size_lr = group["lr"] * group["scaling_lr_scale"] penalty_rms = group["weight_penalty_rms"] if p.ndim > 2 else group["bias_penalty_rms"] eps = group["eps"] decay_scale = group["decay_scale"] @@ -405,10 +405,11 @@ class ScaledAdam(BatchedOptimizer): by this quantity. betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed + scaling_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each non-scalar parameter tensor. If each parameter were decomposed as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the scaling factor on the learning rate of p_scale. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. eps: A general-purpose epsilon to prevent division by zero weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of learning the scale on the parameters. Weight tensors are defined @@ -435,7 +436,7 @@ def __init__( clipping_scale=None, betas=(0.9, 0.98), scalar_lr_scale=0.25, - scaling_lr_scale=0.05, + scaling_lr_scale=0.25, eps=1.0e-08, weight_min_rms=0.005, weight_penalty_rms=0.05, From c2f3406d657f0134ac587f3397fdf23cdec236ff Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 12:10:12 +0800 Subject: [PATCH 0260/1191] decrease scaling_lr_scale from .25 to .1; this makes optim.py test worse 0.09->0.25 but may be more stable in large scsale --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 07cdb9bcfe..ca04476d65 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -436,7 +436,7 @@ def __init__( clipping_scale=None, betas=(0.9, 0.98), scalar_lr_scale=0.25, - scaling_lr_scale=0.25, + scaling_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, weight_penalty_rms=0.05, From ee47de0519a92a0b72fd5a1011e18359dcb40d7f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 12:30:03 +0800 Subject: [PATCH 0261/1191] Change momentum_rate to decrease more slowly. --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ca04476d65..72bc174dcf 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -273,8 +273,8 @@ def momentum_step(group, p, state, grad): summed_grad.mul_(beta) summed_grad += delta - - momentum_rate = 3.0 / (100 + step ** 0.8) + # This formula may be important to tune! + momentum_rate = 2.5 / (100 + step ** 0.666) if random.random() < 0.0002: From 0ba152fde857507f73d4f28c9d4381310f4856cb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 16:34:26 +0800 Subject: [PATCH 0262/1191] Implement decay differently to avoid very large params; decay_scale=0.1 --- egs/librispeech/ASR/zipformer/optim.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 72bc174dcf..86a0170178 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -206,7 +206,6 @@ def scaling_step(group, p, state, grad): # rms will tend to grow faster thanks to parameter noise). beta2 = group["betas"][1] size_lr = group["lr"] * group["scaling_lr_scale"] - penalty_rms = group["weight_penalty_rms"] if p.ndim > 2 else group["bias_penalty_rms"] eps = group["eps"] decay_scale = group["decay_scale"] batch_size = p.shape[0] @@ -229,7 +228,7 @@ def scaling_step(group, p, state, grad): ) # add AdamW-like decay, if we are above "penalty_rms" which is not a hard # maximum but a cutoff for introducing decay. - scale_norm_grad = scale_norm_grad + decay_scale * (param_rms / penalty_rms).log().clamp_(min=0.0) + scale_norm_grad = scale_norm_grad + (decay_scale * size_update_period) * param_rms scale_step = -size_lr * scale_norm_grad delta.add_(p * scale_step) @@ -414,12 +413,12 @@ class ScaledAdam(BatchedOptimizer): weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of learning the scale on the parameters. Weight tensors are defined as anything with more than one element and ndim > 1. - weight_penalty_rms: Value of root-mean-square value of weight tensor, above which we - do adamw-style decay. + weight_penalty_rms: Value of root-mean-square value of weight tensor, that provides + a reference point for when we start to do adamw-style decay. bias_min_rms: Minimum root-mean-square value of bias tensors, defined as anything with more than one element and exactly one tensor dimension i.e. ndim == 1. - bias_penalty_rms: Value of root-mean-square value of bias tensor, above which we - do adamw-style decay. + bias_penalty_rms: Value of root-mean-square value of bias tensor, that provides + a reference point for when we start to do adamw-style decay. scalar_max: Maximum absolute value for scalar parameters (applicable if your model has any parameters with numel() == 1). size_update_period: The periodicity, in steps, with which we update the size (scale) @@ -439,10 +438,8 @@ def __init__( scaling_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, - weight_penalty_rms=0.05, bias_min_rms=1.0e-05, - bias_penalty_rms=0.2, - decay_scale=0.02, + decay_scale=0.1, scalar_max=10.0, size_update_period=4, clipping_update_period=100, @@ -457,9 +454,7 @@ def __init__( scaling_lr_scale=scaling_lr_scale, eps=eps, weight_min_rms=weight_min_rms, - weight_penalty_rms=weight_penalty_rms, bias_min_rms=bias_min_rms, - bias_penalty_rms=bias_penalty_rms, decay_scale=decay_scale, scalar_max=scalar_max, size_update_period=size_update_period, From 305cb43f5cc454db6e3d893d548f5d275b2ce70d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 17:04:47 +0800 Subject: [PATCH 0263/1191] Slower-decreasing formula for momentum_rate --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 86a0170178..24242c23cd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -273,7 +273,7 @@ def momentum_step(group, p, state, grad): summed_grad += delta # This formula may be important to tune! - momentum_rate = 2.5 / (100 + step ** 0.666) + momentum_rate = 2.0 / (100 + step ** 0.5) if random.random() < 0.0002: From 871b597190a54b7cb25d9ac1e512ab9ad1410d4f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 19:14:52 +0800 Subject: [PATCH 0264/1191] Change how the two momentums combine, make it additive. --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 86a0170178..98175c320e 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -294,9 +294,9 @@ def momentum_step(group, p, state, grad): momentum_beta = 0.8 # this is an additional short-time momentum, just to prevent divergence early on. stored_delta.mul_(momentum_beta) - stored_delta.add_(delta + momentum_rate * summed_grad, alpha=(1-momentum_beta)) + stored_delta.add_(delta, alpha=(1-momentum_beta)) - return stored_delta + return stored_delta + momentum_rate * summed_grad def debug_step(group, p, state, grad): From 022634162234bf3a1a0b84dfe6d517b8e7ace245 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 21:03:54 +0800 Subject: [PATCH 0265/1191] Add debug statements; fix test. --- egs/librispeech/ASR/zipformer/optim.py | 6 ++++- egs/librispeech/ASR/zipformer/scaling.py | 31 ++++-------------------- 2 files changed, 10 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 88de39aa9c..c735e65472 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -281,8 +281,12 @@ def momentum_step(group, p, state, grad): delta_corr = (torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + # ratio of var of summed_grad to delta. + var_ratio = (torch.mean(summed_grad ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / + (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={delta_corr.flatten().to('cpu')}") + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={delta_corr.flatten().to('cpu')}, var_ratio={var_ratio.flatten().to('cpu')}") diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5ca765c10a..5f140e9390 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -633,7 +633,8 @@ def diag_inplace(z): # we print a normalized version of the loss, by dividing by the # number of rows. loss = (prod ** 2).mean(dim=(1,2)) * prod.shape[1] - logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}") + logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}, weight_grad_abs_mean={weight_grad.abs().mean()}") + # add the extra gradient term from the orthogonality loss. weight_grad += weight.grad @@ -711,28 +712,6 @@ def forward(self, x: Tensor): return ans -def OrthogonalLinearSpecial(num_channels: int, - penalty_scale: float = 1000.0, - transpose: bool = False): - # returns a parameterized nn.Linear that stays orthogonal, with a special initialization - # that is suitable to use when downsampling; we reshape then multiply by this matrix. - assert num_channels % 2 == 0 - ans = OrthogonalLinear(num_channels, penalty_scale=penalty_scale) - # want to initialize weight as: - # 1/sqrt(2) * M, where M is a block-diagonal matrix with 2x2 blocks [ 1 1; 1 -1 ] - with torch.no_grad(): - inv_sqrt2 = 2 ** -0.5 - ans.weight[:] = 0.0 - ans.weight[0::2, 0::2] = inv_sqrt2 - ans.weight[0::2, 1::2] = inv_sqrt2 - ans.weight[1::2, 0::2] = -inv_sqrt2 if transpose else inv_sqrt2 - ans.weight[1::2, 1::2] = inv_sqrt2 if transpose else -inv_sqrt2 - N = ans.weight.shape[0] - ans.weight *= (torch.arange(N)[:, None] // 2 == - torch.arange(N)[None, :] // 2) - - return ans - class ChunkCausalDepthwiseConv1d(torch.nn.Module): """ @@ -2041,9 +2020,9 @@ def isclose(a, b): assert isclose(x1.grad, x2.grad) def _test_orthogonal_linear(): - for t in (OrthogonalLinear, OrthogonalLinearSpecial): - m = t(128) - m(torch.randn(30, 2, 128)) + t = OrthogonalLinear(128, 128) + m = t(128, 128) + m(torch.randn(30, 2, 128)) if __name__ == "__main__": From db01255f570727d3180d6bc2cd5e5c3a5b7141fc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 21:06:58 +0800 Subject: [PATCH 0266/1191] Bug fix in test --- egs/librispeech/ASR/zipformer/scaling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5f140e9390..1ac6ed1c2e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -2020,8 +2020,7 @@ def isclose(a, b): assert isclose(x1.grad, x2.grad) def _test_orthogonal_linear(): - t = OrthogonalLinear(128, 128) - m = t(128, 128) + m = OrthogonalLinear(128, 128) m(torch.randn(30, 2, 128)) From 1982f20e9b6542508de89f37a7a0050765c1edc5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 22:44:53 +0800 Subject: [PATCH 0267/1191] Remove redundant printout. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1ac6ed1c2e..38ba0f318f 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -633,7 +633,7 @@ def diag_inplace(z): # we print a normalized version of the loss, by dividing by the # number of rows. loss = (prod ** 2).mean(dim=(1,2)) * prod.shape[1] - logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}, weight_grad_abs_mean={weight_grad.abs().mean()}") + logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") # add the extra gradient term from the orthogonality loss. From 99288a10ad48b1e92cd9ddbce76b7f256aa0a7bf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 22:45:52 +0800 Subject: [PATCH 0268/1191] Decrease factor in OrthogonalLinearFunction from 1000.0 to 20.0 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 38ba0f318f..e2529dae72 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -574,7 +574,7 @@ def backward(ctx, y_grad): weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), x.reshape(-1, x.shape[-1])) - penalty_scale = 1000.0 * weight_grad.abs().mean() + penalty_scale = 20.0 * weight_grad.abs().mean() with torch.enable_grad(): weight = weight.detach() From 08dcc6167503572117ea0e37bf58242f1ac05291 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 22:50:04 +0800 Subject: [PATCH 0269/1191] Increase decay beta from 1. - 1. / (10 + step) to 1. - 2.5 / (25 + step) --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c735e65472..d09d7638c5 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -267,8 +267,8 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] - # decay near beginning as early grads may change fast. - beta = 1. - 1 / (10 + step) + # decay near the beginning of training, as early grads may change fast. + beta = 1. - 2.5 / (25 + step) summed_grad.mul_(beta) summed_grad += delta From a03406d788f0dccc6295be91b590cbf114e35070 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Mar 2025 23:59:35 +0800 Subject: [PATCH 0270/1191] Revert 392conv: revert momentum_rate from 2.0 / (100 + step ** 0.5) to 2.5 / (100 + step ** 0.666) --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index d09d7638c5..a8da2eddf8 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -273,7 +273,7 @@ def momentum_step(group, p, state, grad): summed_grad += delta # This formula may be important to tune! - momentum_rate = 2.0 / (100 + step ** 0.5) + momentum_rate = 2.5 / (100 + step ** 0.666) if random.random() < 0.0002: From 889cdf4424037fff845a00689c4b1b6c35a79130 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Mar 2025 11:32:44 +0800 Subject: [PATCH 0271/1191] Add more printouts. --- egs/librispeech/ASR/zipformer/optim.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index a8da2eddf8..c23210f2cb 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -275,18 +275,21 @@ def momentum_step(group, p, state, grad): # This formula may be important to tune! momentum_rate = 2.5 / (100 + step ** 0.666) - if random.random() < 0.0002: + if random.random() < 0.001: eps = 1.0e-20 - delta_corr = (torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / - (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + den = eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) + delta_corr = torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / den - # ratio of var of summed_grad to delta. - var_ratio = (torch.mean(summed_grad ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / - (eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True))) + summed_grad_corr = torch.mean(delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den + + # ratio of var of summed_grad to delta. + var_ratio = torch.mean(summed_grad ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / den - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={delta_corr.flatten().to('cpu')}, var_ratio={var_ratio.flatten().to('cpu')}") + def f(x): + return x.flatten().to('cpu') + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={f(delta_corr)}, summed_grad_corr={f(summed_grad_corr)}, summed_grad_corr_scaled={f(summed_grad_corr*momentum_rate)}, var_ratio={f(var_ratio)}") From 72bd3134ba6f065ee28001b515c506c9ee3b7ba2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Mar 2025 11:41:41 +0800 Subject: [PATCH 0272/1191] Add another diagnostic. --- egs/librispeech/ASR/zipformer/optim.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c23210f2cb..0f2d7d9b9c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -275,6 +275,10 @@ def momentum_step(group, p, state, grad): # This formula may be important to tune! momentum_rate = 2.5 / (100 + step ** 0.666) + momentum_beta = 0.8 # this is an additional short-time momentum, just to prevent divergence early on. + stored_delta.mul_(momentum_beta) + stored_delta.add_(delta, alpha=(1-momentum_beta)) + if random.random() < 0.001: eps = 1.0e-20 @@ -283,13 +287,14 @@ def momentum_step(group, p, state, grad): summed_grad_corr = torch.mean(delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den + summed_grad_corr_slow = torch.mean(stored_delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den # ratio of var of summed_grad to delta. var_ratio = torch.mean(summed_grad ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / den def f(x): return x.flatten().to('cpu') - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={f(delta_corr)}, summed_grad_corr={f(summed_grad_corr)}, summed_grad_corr_scaled={f(summed_grad_corr*momentum_rate)}, var_ratio={f(var_ratio)}") + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={f(delta_corr)}, summed_grad_corr={f(summed_grad_corr)}, summed_grad_corr_slow={f(summed_grad_corr_slow)}, summed_grad_corr_scaled={f(summed_grad_corr*momentum_rate)}, var_ratio={f(var_ratio)}") @@ -299,10 +304,6 @@ def f(x): prev_delta.copy_(delta) - momentum_beta = 0.8 # this is an additional short-time momentum, just to prevent divergence early on. - stored_delta.mul_(momentum_beta) - stored_delta.add_(delta, alpha=(1-momentum_beta)) - return stored_delta + momentum_rate * summed_grad From f65d012ff48e7eff723cd63443a8bfd2c5a9c515 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Mar 2025 16:53:32 +0800 Subject: [PATCH 0273/1191] Put momentum inside scaling update and apply LR after momentum. --- egs/librispeech/ASR/zipformer/optim.py | 154 ++++++++++++------------- 1 file changed, 75 insertions(+), 79 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 0f2d7d9b9c..34c29674c2 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -124,10 +124,7 @@ def batched_params(self, param_group, group_params_names): def basic_step(group, p, state, grad): - # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. - lr = group["lr"] - if p.numel() == p.shape[0]: - lr = lr * group["scalar_lr_scale"] + # computes basic Adam normalized-grad using beta2 (dividing by gradient stddev) only. no momentum yet. beta2 = group["betas"][1] eps = group["eps"] # p shape: (batch_size,) or (batch_size, 1, [1,..]) @@ -147,12 +144,82 @@ def basic_step(group, p, state, grad): exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) denom = exp_avg_sq.sqrt().add_(eps) - return -lr * grad / denom + return grad / denom -def scaling_step(group, p, state, grad): +def momentum_step(group, p, state, grad): delta = basic_step(group, p, state, grad) + + #beta1 = group["betas"][0] + + lr = group["lr"] + step = state["step"] + + try: + summed_grad = state["summed_grad"] + stored_delta = state["delta"] + prev_delta = state["prev_delta"] + except KeyError: + summed_grad = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + prev_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["summed_grad"] = summed_grad + state["delta"] = stored_delta + state["prev_delta"] = prev_delta + + + if p.numel() == p.shape[0]: + # scalar. use conventional momentum. + beta = 0.9 + stored_delta.mul_(beta).add_(delta, alpha=(1-beta)) + lr = lr * group["scalar_lr_scale"] + return -lr * stored_delta + + + + # decay near the beginning of training, as early grads may change fast. + beta = 1. - 1.5 / (15 + step) + summed_grad.mul_(beta) + summed_grad += delta + + # This formula may be important to tune! + momentum_rate = 2.5 / (100 + step ** 0.666) + + momentum_beta = 0.8 # this is an additional short-time momentum, just to prevent divergence early on. + stored_delta.mul_(momentum_beta) + stored_delta.add_(delta, alpha=(1-momentum_beta)) + + if random.random() < 0.001: + + eps = 1.0e-20 + den = eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) + delta_corr = torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / den + + + summed_grad_corr = torch.mean(delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den + summed_grad_corr_slow = torch.mean(stored_delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den + + # ratio of var of summed_grad to delta. + var_ratio = torch.mean(summed_grad ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / den + + def f(x): + return x.flatten().to('cpu') + logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={f(delta_corr)}, summed_grad_corr={f(summed_grad_corr)}, summed_grad_corr_slow={f(summed_grad_corr_slow)}, summed_grad_corr_scaled={f(summed_grad_corr*momentum_rate)}, var_ratio={f(var_ratio)}") + + + + + prev_delta.copy_(delta) + + return -lr * (stored_delta + momentum_rate * summed_grad) + + + + + +def scaling_step(group, p, state, grad): + delta = momentum_step(group, p, state, grad) if p.numel() == p.shape[0]: return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) @@ -236,83 +303,12 @@ def scaling_step(group, p, state, grad): return delta -def momentum_step(group, p, state, grad): - delta = scaling_step(group, p, state, grad) - - #beta1 = group["betas"][0] - - # hardcode betas. - # see simulate_params.py on my laptop for how I got these settings. - - try: - summed_grad = state["summed_grad"] - stored_delta = state["delta"] - prev_delta = state["prev_delta"] - except KeyError: - summed_grad = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - prev_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - state["summed_grad"] = summed_grad - state["delta"] = stored_delta - state["prev_delta"] = prev_delta - - - if p.numel() == p.shape[0]: - # scalar. use conventional momentum. - beta = 0.9 - stored_delta.mul_(beta).add_(delta, alpha=(1-beta)) - return stored_delta - - - lr = group["lr"] - step = state["step"] - - # decay near the beginning of training, as early grads may change fast. - beta = 1. - 2.5 / (25 + step) - summed_grad.mul_(beta) - summed_grad += delta - - # This formula may be important to tune! - momentum_rate = 2.5 / (100 + step ** 0.666) - - momentum_beta = 0.8 # this is an additional short-time momentum, just to prevent divergence early on. - stored_delta.mul_(momentum_beta) - stored_delta.add_(delta, alpha=(1-momentum_beta)) - - if random.random() < 0.001: - - eps = 1.0e-20 - den = eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) - delta_corr = torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / den - - - summed_grad_corr = torch.mean(delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den - summed_grad_corr_slow = torch.mean(stored_delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den - - # ratio of var of summed_grad to delta. - var_ratio = torch.mean(summed_grad ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / den - - def f(x): - return x.flatten().to('cpu') - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={f(delta_corr)}, summed_grad_corr={f(summed_grad_corr)}, summed_grad_corr_slow={f(summed_grad_corr_slow)}, summed_grad_corr_scaled={f(summed_grad_corr*momentum_rate)}, var_ratio={f(var_ratio)}") - - - - # the 0.5 * (prev_delta + delta) is a very basic, dumb momentum that is to stop - # divergence. - - prev_delta.copy_(delta) - - - return stored_delta + momentum_rate * summed_grad - - def debug_step(group, p, state, grad): debug_interval = group["debug_interval"] debug_buffer_size = 256 step = state["step"] - delta = momentum_step(group, p, state, grad) + delta = scaling_step(group, p, state, grad) if debug_interval == 0 or step % debug_interval != 0: return delta @@ -1318,7 +1314,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.015, clipping_scale=2.0, eps=1.0e-20) + optim = ScaledAdam(m.named_parameters(), lr=0.01, clipping_scale=2.0, eps=1.0e-20) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From f27fe672eb6929625bff473e796f5d64312a370d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Mar 2025 11:09:13 +0800 Subject: [PATCH 0274/1191] Version where optim.py works well --- egs/librispeech/ASR/zipformer/optim.py | 69 ++++++++++---------------- 1 file changed, 26 insertions(+), 43 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 34c29674c2..b2bb5491df 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -156,17 +156,28 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] + try: - summed_grad = state["summed_grad"] stored_delta = state["delta"] - prev_delta = state["prev_delta"] + if p.numel() != p.shape[0]: + scales = state["scales"] + betas = state["betas"] except KeyError: - summed_grad = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - prev_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - state["summed_grad"] = summed_grad - state["delta"] = stored_delta - state["prev_delta"] = prev_delta + if p.numel() == p.shape[0]: + # scalar. use conventional momentum. + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + else: + betas = torch.tensor([ 0.96, 0.9984, 0.999936]).to(device=p.device) + # caution: the scales include the 1-beta factor. + scales = torch.tensor([ 2, 4, 8 ]).to(device=p.device) * (1-betas) + for _ in range(p.ndim): + betas, scales = betas.unsqueeze(-1), scales.unsqueeze(-1) + + stored_delta = torch.zeros(len(betas), *p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + state["betas"] = betas + state["scales"] = scales if p.numel() == p.shape[0]: @@ -178,41 +189,13 @@ def momentum_step(group, p, state, grad): - # decay near the beginning of training, as early grads may change fast. - beta = 1. - 1.5 / (15 + step) - summed_grad.mul_(beta) - summed_grad += delta - - # This formula may be important to tune! - momentum_rate = 2.5 / (100 + step ** 0.666) - - momentum_beta = 0.8 # this is an additional short-time momentum, just to prevent divergence early on. - stored_delta.mul_(momentum_beta) - stored_delta.add_(delta, alpha=(1-momentum_beta)) - - if random.random() < 0.001: - - eps = 1.0e-20 - den = eps + torch.mean(delta ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) - delta_corr = torch.mean(delta * prev_delta, dim=tuple(range(1, p.ndim)), keepdim=True) / den - - - summed_grad_corr = torch.mean(delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den - summed_grad_corr_slow = torch.mean(stored_delta * summed_grad, dim=tuple(range(1, p.ndim)), keepdim=True) / den - - # ratio of var of summed_grad to delta. - var_ratio = torch.mean(summed_grad ** 2, dim=tuple(range(1, p.ndim)), keepdim=True) / den - - def f(x): - return x.flatten().to('cpu') - logging.info(f"step={step}, shape={list(p.shape)}, lr={lr}, momentum_rate={momentum_rate}, delta_corr={f(delta_corr)}, summed_grad_corr={f(summed_grad_corr)}, summed_grad_corr_slow={f(summed_grad_corr_slow)}, summed_grad_corr_scaled={f(summed_grad_corr*momentum_rate)}, var_ratio={f(var_ratio)}") - - - - - prev_delta.copy_(delta) + # an extra decay of the deltas near the beginning of training, as early grads may change fast. + decay = 1. - 1.5 / (15 + step) + stored_delta.mul_(decay) + stored_delta *= betas + stored_delta += delta - return -lr * (stored_delta + momentum_rate * summed_grad) + return -lr * (delta + (stored_delta * scales).sum(dim=0)) @@ -1314,7 +1297,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.01, clipping_scale=2.0, eps=1.0e-20) + optim = ScaledAdam(m.named_parameters(), lr=0.008, clipping_scale=2.0, eps=1.0e-20) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From c7f5118e0cf55d22e1d623654454401b52a02846 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Mar 2025 11:15:34 +0800 Subject: [PATCH 0275/1191] increse scaling_lr_scale --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b2bb5491df..5593a14dcf 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -422,7 +422,7 @@ def __init__( clipping_scale=None, betas=(0.9, 0.98), scalar_lr_scale=0.25, - scaling_lr_scale=0.1, + scaling_lr_scale=0.2, eps=1.0e-08, weight_min_rms=0.005, bias_min_rms=1.0e-05, From 697e3d110bc29568fcd3553140ac97694a8b38a2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Mar 2025 20:50:51 +0800 Subject: [PATCH 0276/1191] Add extra debug info about delta; five times larger decay scale in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 7 ++++--- egs/librispeech/ASR/zipformer/train.py | 11 +++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 5593a14dcf..3e8bb21dbd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -299,7 +299,7 @@ def debug_step(group, p, state, grad): try: debug_info = state["debug_info"] except KeyError: - debug_info = torch.zeros(debug_buffer_size, p.shape[0], 2, + debug_info = torch.zeros(debug_buffer_size, p.shape[0], 3, device=p.device, dtype=torch.float) state["debug_info"] = debug_info @@ -318,6 +318,7 @@ def maybe_rms(x): debug_info[:, 0] = maybe_rms(p) debug_info[:, 1] = maybe_rms(grad) + debug_info[:, 2] = maybe_rms(delta) return delta @@ -347,7 +348,7 @@ def _write_debug_info(group, state, param_names, summary_writer): arange = torch.arange(debug_buffer_size) steps = debug_interval * (arange - debug_buffer_size) + cur_step - for i, legend in enumerate(['param_rms', 'grad_rms']): + for i, legend in enumerate(['param_rms', 'grad_rms', 'delta_rms']): for name, info in zip(param_names, debug_info[..., i].unbind(dim=1)): debug_str = f"debug/{legend}/{name}" for step, value in zip(steps.tolist(), info.tolist()): @@ -426,7 +427,7 @@ def __init__( eps=1.0e-08, weight_min_rms=0.005, bias_min_rms=1.0e-05, - decay_scale=0.1, + decay_scale=0.5, scalar_max=10.0, size_update_period=4, clipping_update_period=100, diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 343711c2b2..2ed49b731a 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -371,6 +371,8 @@ def get_parser(): default=0, help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. + Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file + opened, seeked-in and closed for each scalar that is written). """ ) @@ -1291,14 +1293,7 @@ def run(rank, world_size, args): logging.info("Training started") if args.tensorboard and rank == 0: - # the reason for the very large max_queue is this: if --dump-debug-interval is set, - # e.g. to 2560, every that-many batches we will dump a very large number of events - # to the writer. These are added to a queue that is drained raather slowly. - # If we make the max_queue large enough to include all the events from calling - # "optimizer.write_debug_info(), we can continue with training and let the - # background thread take care of dumping those events at its own speed. - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard", - max_queue=100) + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") else: tb_writer = None From 17379fab5f6e352eaf392cd0a83653038141a308 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Mar 2025 17:29:13 +0800 Subject: [PATCH 0277/1191] Implement rand_floor in ExpNorm, set it to 0.1. --- egs/librispeech/ASR/zipformer/scaling.py | 34 +++++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e2529dae72..9249bd174b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -364,9 +364,12 @@ def backward(ctx, x_grad, *args): -def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): +def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, floor: Optional[Tensor]): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - scales = (1. - (-x_norm).exp()) / x_norm + num = (1. - (-x_norm).exp()) + if floor is not None: + num = torch.maximum(num, floor) + scales = num / x_norm scales = scale * scales return (x * scales) @@ -383,14 +386,24 @@ def forward( x: Tensor, scale: Tensor, channel_dim: int, + rand_floor: float, ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim + ctx.rand_floor = rand_floor + if rand_floor != 0.0: + shape = list(x.shape) + shape[channel_dim] = 1 + floor = torch.where(torch.rand(*shape, device=x.device) < 0.1, rand_floor, 0.0) + else: + floor = None + ctx.floor = floor + ctx.save_for_backward(x, scale) - return _exp_norm(x, scale, channel_dim) - return ans + return _exp_norm(x, scale, channel_dim, floor) + @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: @@ -399,11 +412,12 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: with torch.cuda.amp.autocast(enabled=False): x, scale = x.to(torch.float32), scale.to(torch.float32) x, scale = x.detach(), scale.detach() + x.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - ans = _exp_norm(x, scale, ctx.channel_dim) + ans = _exp_norm(x, scale, ctx.channel_dim, ctx.floor) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -411,7 +425,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(scale.grad), None + return x.grad, c(scale.grad), None, None @@ -441,16 +455,22 @@ class ExpNorm(torch.nn.Module): interpreted as an offset from the input's ndim if negative. This is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. + rand_floor: if not 0.0: during training, for 10% of the vectors + we will randomly floor the numerator of the expression for the + scales (1. - (-x_norm).exp()), to this value. This is intended + to discourage the network to make the inputs smaller than this. """ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. + rand_floor: FloatLike = 0.0, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(1.7)) + self.rand_floor = rand_floor self.name = None @@ -465,7 +485,7 @@ def forward(self, x: Tensor) -> Tensor: self.scale, min=0.5, max=2.5, training=self.training) ans = ExpNormFunction.apply( - x, scale, self.channel_dim, + x, scale, self.channel_dim, float(self.rand_floor) if self.training else 0.0, ) if random.random() < 0.002: From 1748425dca584fcb35cbc4547e14d9140934acb3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Mar 2025 17:35:26 +0800 Subject: [PATCH 0278/1191] Set default rand_floor to 0.1 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9249bd174b..1550488513 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -464,7 +464,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.0, + rand_floor: FloatLike = 0.1, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From 6c6d13279068ad8f1637c0f10b355e4feacebb21 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Mar 2025 18:14:53 +0800 Subject: [PATCH 0279/1191] Increase rand_floor from .1 to .25 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1550488513..a2fc49e65b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -464,7 +464,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.1, + rand_floor: FloatLike = 0.25, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From 6b029c4a719ac019a36afcda9d3ba128037dce82 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Mar 2025 13:39:12 +0800 Subject: [PATCH 0280/1191] Double betas from 2,4,8 to 4,8,16 and increase scalar_lr_scale,scaling_lr_scale from .25,.2 to .5,.5 --- egs/librispeech/ASR/zipformer/optim.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3e8bb21dbd..7c72f425ed 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -169,8 +169,8 @@ def momentum_step(group, p, state, grad): state["delta"] = stored_delta else: betas = torch.tensor([ 0.96, 0.9984, 0.999936]).to(device=p.device) - # caution: the scales include the 1-beta factor. - scales = torch.tensor([ 2, 4, 8 ]).to(device=p.device) * (1-betas) + # caution: the scales will include the 1-beta factor. + scales = torch.tensor([ 4, 8, 16 ]).to(device=p.device) * (1-betas) for _ in range(p.ndim): betas, scales = betas.unsqueeze(-1), scales.unsqueeze(-1) @@ -422,8 +422,8 @@ def __init__( lr=3e-02, clipping_scale=None, betas=(0.9, 0.98), - scalar_lr_scale=0.25, - scaling_lr_scale=0.2, + scalar_lr_scale=0.5, + scaling_lr_scale=0.5, eps=1.0e-08, weight_min_rms=0.005, bias_min_rms=1.0e-05, @@ -1298,7 +1298,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.008, clipping_scale=2.0, eps=1.0e-20) + optim = ScaledAdam(m.named_parameters(), lr=0.005, clipping_scale=2.0, eps=1.0e-20) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From 8df5e0f10aab3fe8bacfec5b65405c2ad1dad669 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Mar 2025 19:23:16 +0800 Subject: [PATCH 0281/1191] Major refactoring of optimizer, use transformation. --- egs/librispeech/ASR/zipformer/optim.py | 229 ++++++++++++------------- 1 file changed, 112 insertions(+), 117 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7c72f425ed..13f443ced4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -125,12 +125,13 @@ def batched_params(self, param_group, group_params_names): def basic_step(group, p, state, grad): # computes basic Adam normalized-grad using beta2 (dividing by gradient stddev) only. no momentum yet. - beta2 = group["betas"][1] + beta2 = group["beta2"] eps = group["eps"] # p shape: (batch_size,) or (batch_size, 1, [1,..]) try: exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) or (batch_size, 1, [1,..]) except KeyError: + assert state["step"] < 2 exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) state["exp_avg_sq"] = exp_avg_sq @@ -160,24 +161,26 @@ def momentum_step(group, p, state, grad): try: stored_delta = state["delta"] if p.numel() != p.shape[0]: - scales = state["scales"] + alphas = state["alphas"] betas = state["betas"] - except KeyError: + except KeyError as e: + assert step < 2 if p.numel() == p.shape[0]: # scalar. use conventional momentum. stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) state["delta"] = stored_delta else: - betas = torch.tensor([ 0.96, 0.9984, 0.999936]).to(device=p.device) - # caution: the scales will include the 1-beta factor. - scales = torch.tensor([ 4, 8, 16 ]).to(device=p.device) * (1-betas) + # e.g. group["betas"] = (.96, .9984, .99936), + # group["scales"] = (4., 8., 16.), + betas = torch.tensor(group["betas"]).to(device=p.device) + alphas = torch.tensor(group["scales"]).to(device=p.device) * (1-betas) for _ in range(p.ndim): - betas, scales = betas.unsqueeze(-1), scales.unsqueeze(-1) + betas, alphas = betas.unsqueeze(-1), alphas.unsqueeze(-1) stored_delta = torch.zeros(len(betas), *p.shape, device=p.device, dtype=torch.float) state["delta"] = stored_delta state["betas"] = betas - state["scales"] = scales + state["alphas"] = alphas if p.numel() == p.shape[0]: @@ -195,95 +198,69 @@ def momentum_step(group, p, state, grad): stored_delta *= betas stored_delta += delta - return -lr * (delta + (stored_delta * scales).sum(dim=0)) - - - + return -lr * (delta + (stored_delta * alphas).sum(dim=0)) -def scaling_step(group, p, state, grad): - delta = momentum_step(group, p, state, grad) - if p.numel() == p.shape[0]: - return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) - - step = state["step"] - size_update_period = group["size_update_period"] - - try: - param_rms = state["param_rms"] - scale_grads = state["scale_grads"] - scale_exp_avg_sq = state["scale_exp_avg_sq"] - except KeyError: - # we know p.ndim > 1 because we'd have returned above if not, so don't worry - # about the speial case of dim=[] that pytorch treats inconsistently. - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - param_rms = param_rms.to(torch.float) - scale_exp_avg_sq = torch.zeros_like(param_rms) - scale_grads = torch.zeros(size_update_period, *param_rms.shape, - dtype=torch.float, device=p.device) - state["param_rms"] = param_rms - state["scale_grads"] = scale_grads - state["scale_exp_avg_sq"] = scale_exp_avg_sq - - - # on every step, update the gradient w.r.t. the scale of the parameter, we - # store these as a batch and periodically update the size (for speed only, to - # avoid too many operations). - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - - # periodically recompute the value of param_rms. - if step % size_update_period == size_update_period - 1: - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) +def forward_transform_param(group, p): + """ + Returns a transformed version of the batch of parameters (dimension 0 of p is the batch + of same-shaped parameters). + The transformation is from a parameter to a (parameter-direction, log-weight), where + parameter-direction has unit RMS value and log-weight + """ + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel == 1: + # scalar parameters are treated specially. scalar_lr_scale is to control + # the learning-rate of scalars. + return p.reshape(batch_size, 1) / group["scalar_lr_scale"] - # would be p.ndim > 1 not p.ndim > 2 but one dim is for batch of tensors. is_weight = (p.ndim > 2) min_rms = group["weight_min_rms"] if is_weight else group["bias_min_rms"] + p = p.reshape(batch_size, numel) + sumsq = (p ** 2).sum(dim=1, keepdim=True) + min_sumsq = (min_rms ** 2) * numel # if sumsq is less than this we pad with an extra element. + sumsq_clamped = sumsq.clamp(min=min_sumsq) + pad = (sumsq_clamped - sumsq).sqrt() + scale = (sumsq_clamped / numel).sqrt() # must be nonzero thanks to min_rms + + # scaling_lr_scale is to control the learning-rate of scaling factors. + log_scale = (1 / group["scaling_lr_scale"]) * scale.log() + return torch.cat((p / scale, pad / scale, log_scale), dim=1) + +def reverse_transform_param(group, p, orig_shape): + batch_size = p.shape[0] + if p.numel() == batch_size: + return (p * group["scalar_lr_scale"]).reshape(*orig_shape) + numel = p.shape[1] - 2 # numel of original shape + + p_padded = p[:, :-1] + p_padded = p_padded / ((p_padded ** 2).sum(dim=1, keepdim=True) / numel).sqrt() # normalize rms to 1. + scale = (p[:, -1:] * group["scaling_lr_scale"]).exp() + p = p_padded[:, :-1] * scale # the :-1 is to remove the padding element. + return p.reshape(*orig_shape) + + +def forward_transform_param_and_grad(group, p, grad): + # returns new parameter. + p_shape = p.shape + p_flat = forward_transform_param(group, p).detach() + with torch.enable_grad(): + p_flat.requires_grad = True + p_reconstruct = reverse_transform_param(group, p_flat, p.shape) + p_reconstruct.backward(gradient=grad) + return p_flat.detach(), p_flat.grad - # scale the step size by param_rms. This is the most important "scaling" part of - # ScaledAdam - delta *= param_rms.clamp(min=min_rms) - - if step % size_update_period == size_update_period - 1 and step > 0: - # This block updates the size of parameter by adding a step ("delta") value in - # the direction of either shrinking or growing it. it also includes a modified - # form of adamw-like shrinkage, which we modify a bit to ensure there is a unique - # optimum for the scales (since thanks to the "delta *= param_rms.clamp(min=min_rms)", - # there is no longer the size-stabilizing phenomenon as in Adam whereby parameters with smaller - # rms will tend to grow faster thanks to parameter noise). - beta2 = group["betas"][1] - size_lr = group["lr"] * group["scaling_lr_scale"] - eps = group["eps"] - decay_scale = group["decay_scale"] - batch_size = p.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_norm_grad = ( - (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - # add AdamW-like decay, if we are above "penalty_rms" which is not a hard - # maximum but a cutoff for introducing decay. - scale_norm_grad = scale_norm_grad + (decay_scale * size_update_period) * param_rms - scale_step = -size_lr * scale_norm_grad - delta.add_(p * scale_step) +def scaling_step(group, p, state, grad): + # returns new parameter. + p_shape = p.shape + p_flat, grad_flat = forward_transform_param_and_grad(group, p, grad) + + p_flat += momentum_step(group, p_flat, state, grad_flat) - return delta + p = reverse_transform_param(group, p_flat, p.shape) + return p def debug_step(group, p, state, grad): @@ -291,15 +268,15 @@ def debug_step(group, p, state, grad): debug_buffer_size = 256 step = state["step"] - delta = scaling_step(group, p, state, grad) + p = scaling_step(group, p, state, grad) if debug_interval == 0 or step % debug_interval != 0: - return delta + return p try: debug_info = state["debug_info"] except KeyError: - debug_info = torch.zeros(debug_buffer_size, p.shape[0], 3, + debug_info = torch.zeros(debug_buffer_size, p.shape[0], 2, device=p.device, dtype=torch.float) state["debug_info"] = debug_info @@ -318,9 +295,8 @@ def maybe_rms(x): debug_info[:, 0] = maybe_rms(p) debug_info[:, 1] = maybe_rms(grad) - debug_info[:, 2] = maybe_rms(delta) - return delta + return p def _write_debug_info(group, state, param_names, summary_writer): @@ -348,7 +324,7 @@ def _write_debug_info(group, state, param_names, summary_writer): arange = torch.arange(debug_buffer_size) steps = debug_interval * (arange - debug_buffer_size) + cur_step - for i, legend in enumerate(['param_rms', 'grad_rms', 'delta_rms']): + for i, legend in enumerate(['param_rms', 'grad_rms']): for name, info in zip(param_names, debug_info[..., i].unbind(dim=1)): debug_str = f"debug/{legend}/{name}" for step, value in zip(steps.tolist(), info.tolist()): @@ -390,8 +366,12 @@ class ScaledAdam(BatchedOptimizer): we mean after multiplying by the rms parameter value for this tensor [for non-scalars]; this is appropriate because our update is scaled by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + beta2: beta2 is the momentum constant for moving-grad-squared as in Adam. Must satisfy 0 < beta <= beta2 < 1. + betas: a list of decay constants for momentum on the parameter-change + scales: a list of scales corresponding to each of the betas, that we multiply + each momentum-update by. Implicitly there is also a beta=0, scale=1, + i.e. a non-decayed update. scaling_lr_scale: A scaling factor on the learning rate, that we use to update the scale of each non-scalar parameter tensor. If each parameter were decomposed as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale @@ -421,8 +401,10 @@ def __init__( params, lr=3e-02, clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.5, + beta2=0.98, + betas=(.96, .9984, .99936), + scales=(4., 8., 16.), + scalar_lr_scale=0.2, scaling_lr_scale=0.5, eps=1.0e-08, weight_min_rms=0.005, @@ -437,7 +419,9 @@ def __init__( defaults = dict( lr=lr, clipping_scale=clipping_scale, + beta2=beta2, betas=betas, + scales=scales, scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, @@ -621,7 +605,7 @@ def step(self, closure=None): cur_step = 0 grad = (p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)) - p += debug_step(group, p.detach(), state, grad) + p[:] = debug_step(group, p.detach(), state, grad) if p.numel() == p.shape[0]: # scalar parameter scalar_max = group["scalar_max"] @@ -694,7 +678,8 @@ def _get_clipping_scale( scalar_lr_scale**2 ) # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + param_meansq = (p ** 2).mean(dim=tuple(range(1, p.ndim)), keepdim=True) + tot_sumsq += ((grad ** 2) * param_meansq).sum() tot_norm = tot_sumsq.sqrt() if "model_norms" not in first_state: @@ -759,7 +744,7 @@ def _get_clipping_scale( self._show_gradient_dominating_parameter( tuples, tot_sumsq, group["scalar_lr_scale"] ) - self._show_param_with_unusual_grad(tuples) + self._show_param_with_unusual_grad(group, tuples) if ans == 0.0: for (p, state, param_names) in tuples: @@ -768,8 +753,9 @@ def _get_clipping_scale( return ans def _show_param_with_unusual_grad( - self, - tuples: List[Tuple[Tensor, dict, List[str]]], + self, + group, + tuples: List[Tuple[Tensor, dict, List[str]]], ): """ Print information about parameter which has the largest ratio of grad-on-this-batch @@ -786,13 +772,10 @@ def _show_param_with_unusual_grad( ratios_names = [ ] for (p, state, batch_param_names) in tuples: dims = list(range(1, p.ndim)) - def mean(x): - # workaround for bad interface of torch's "mean" for when dims is the empty list. - if len(dims) > 0: - return x.mean(dim=dims) - else: - return x - grad_ratio = (mean(p.grad ** 2) / state["exp_avg_sq"].mean(dim=dims)).sqrt() + + p_flat, grad_flat = forward_transform_param_and_grad(group, p, p.grad) + + grad_ratio = ((grad_flat ** 2).mean(dim=1) / state["exp_avg_sq"].mean(dim=1)).sqrt() ratios_names += zip(grad_ratio.to('cpu').tolist(), batch_param_names) ratios_names = sorted(ratios_names, reverse=True) @@ -828,22 +811,24 @@ def _show_gradient_dominating_parameter( batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars # Dummy values used by following `zip` statement. - batch_rms_orig = torch.full( - p.shape, scalar_lr_scale, device=batch_grad.device + batch_meansq = torch.full( + p.shape, scalar_lr_scale ** 2, device=batch_grad.device ) else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2 + batch_meansq = (p ** 2).mean(dim=tuple(range(1, p.ndim)), keepdim=True) + + batch_rms = batch_meansq.sqrt() # rms of each parameter. + batch_sumsq = (batch_grad * batch_rms) ** 2 # sum-square of grad times param rms + if batch_grad.ndim > 1: # need to guard it with if-statement because sum() sums over # all dims if dim == (). - batch_sumsq_orig = batch_sumsq_orig.sum( + batch_sumsq = batch_sumsq.sum( dim=list(range(1, batch_grad.ndim)) ) for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + batch_param_names, batch_sumsq, batch_rms, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1351,6 +1336,15 @@ def _test_scaled_adam(hidden_dim: int): logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") +def _test_transform_params(): + group = { "bias_min_rms": 0.001, "weight_min_rms": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5 } + for scale in [ 0.0, 1.0e-05, 0.001, 0.01, 1.0, 10.0 ]: + for shape in [ (1, 1), (2, 1), (2, 2), (2, 3, 4), (3, 10, 20), (4,) ]: + p = scale * torch.randn(*shape) + q = forward_transform_param(group, p) + r = reverse_transform_param(group, q, p.shape) + assert torch.allclose(p, r), (p, q, r) + if __name__ == "__main__": torch.set_num_threads(1) @@ -1369,5 +1363,6 @@ def _test_scaled_adam(hidden_dim: int): else: hidden_dim = 200 + _test_transform_params() _test_scaled_adam(hidden_dim) _test_eden() From e1ff097a7f91e8cdda4f5bda1d34d70eb06c2b63 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Mar 2025 21:48:46 +0800 Subject: [PATCH 0282/1191] Add optimization w.r.t. scaling factors. --- egs/librispeech/ASR/zipformer/optim.py | 80 ++++++++++++++++++++++---- 1 file changed, 70 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 13f443ced4..76288b7854 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -201,6 +201,33 @@ def momentum_step(group, p, state, grad): return -lr * (delta + (stored_delta * alphas).sum(dim=0)) +def get_scaling_shapes(shape): + """ shape is a list representing a shape of a batch of tensors, + interpreted as (batch_size, a, b, ..). We return a list of + shapes of tensors to add to the "expanded representation" of the + tensor, that will be interpreted as (offsets to) scales on + various dimensions. + """ + num_nontrivial = sum ([ 1 if x > 1 else 0 for x in shape[1:] ]) + ans = [ ] + if num_nontrivial <= 1: + # there are no 'scaling shapes' as the tensor has less than two + # nontrivial dims. + return ans + for i in range(1, len(shape)): + if shape[i] != 1: + l = list(shape) + l[i] = 1 + ans.append(l) + return ans + +def prod_of_list(seq): + prod = 1 + for i in seq: + prod = prod * i + return prod + + def forward_transform_param(group, p): """ Returns a transformed version of the batch of parameters (dimension 0 of p is the batch @@ -217,28 +244,56 @@ def forward_transform_param(group, p): is_weight = (p.ndim > 2) min_rms = group["weight_min_rms"] if is_weight else group["bias_min_rms"] - p = p.reshape(batch_size, numel) - sumsq = (p ** 2).sum(dim=1, keepdim=True) + p_flat = p.reshape(batch_size, numel) + sumsq = (p_flat ** 2).sum(dim=1, keepdim=True) min_sumsq = (min_rms ** 2) * numel # if sumsq is less than this we pad with an extra element. sumsq_clamped = sumsq.clamp(min=min_sumsq) pad = (sumsq_clamped - sumsq).sqrt() scale = (sumsq_clamped / numel).sqrt() # must be nonzero thanks to min_rms # scaling_lr_scale is to control the learning-rate of scaling factors. + # log_scale controls the overall scale of this tensor log_scale = (1 / group["scaling_lr_scale"]) * scale.log() - return torch.cat((p / scale, pad / scale, log_scale), dim=1) + + # We also include scaling factors that will scale individual rows and columns of the + # weights. These are initially all zero, we'll scale by (1 + coeff * this_scaling_factor) + + scaling_dim = sum([ prod_of_list(l[1:]) for l in get_scaling_shapes(p.shape) ]) + scaling_factors = torch.zeros(batch_size, scaling_dim, device=p.device, dtype=p.dtype) + + ans = torch.cat((p_flat / scale, pad / scale, log_scale, scaling_factors), dim=1) + return ans def reverse_transform_param(group, p, orig_shape): batch_size = p.shape[0] if p.numel() == batch_size: return (p * group["scalar_lr_scale"]).reshape(*orig_shape) - numel = p.shape[1] - 2 # numel of original shape - - p_padded = p[:, :-1] + # numel is num elements of each parameter tensor in the batch. + numel = prod_of_list(orig_shape[1:]) + p_padded = p[:, :numel+1] # orig tensor plus one padding element p_padded = p_padded / ((p_padded ** 2).sum(dim=1, keepdim=True) / numel).sqrt() # normalize rms to 1. - scale = (p[:, -1:] * group["scaling_lr_scale"]).exp() - p = p_padded[:, :-1] * scale # the :-1 is to remove the padding element. - return p.reshape(*orig_shape) + + is_weight = (len(orig_shape) > 2) + max_rms = group["weight_max_rms"] if is_weight else group["bias_max_rms"] + scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]).exp().clamp(max=max_rms) + + q = p_padded[:, :-1] * scale # the :-1 is to remove the padding element. + q = q.reshape(*orig_shape) + # Now include the scaling factors. these were originally all zero as returned from + # forward_transform_param. + offset = numel + 2 # + 1 for the padding element and the log-scale. + + S = group["scaling_lr_scale"] + shapes = get_scaling_shapes(orig_shape) + num_shapes = len(shapes) + for scaling_shape in shapes: + this_numel = prod_of_list(scaling_shape[1:]) + assert offset + this_numel <= p.shape[1] + scales = p[:, offset:offset+this_numel].reshape(*scaling_shape) + offset = offset + this_numel + scales = 1.0 + (S / num_shapes) * scales + q = q * scales + return q def forward_transform_param_and_grad(group, p, grad): @@ -408,7 +463,9 @@ def __init__( scaling_lr_scale=0.5, eps=1.0e-08, weight_min_rms=0.005, + weight_max_rms=1.0, bias_min_rms=1.0e-05, + bias_max_rms=5.0, decay_scale=0.5, scalar_max=10.0, size_update_period=4, @@ -426,7 +483,9 @@ def __init__( scaling_lr_scale=scaling_lr_scale, eps=eps, weight_min_rms=weight_min_rms, + bias_max_rms=bias_max_rms, bias_min_rms=bias_min_rms, + weight_max_rms=weight_max_rms, decay_scale=decay_scale, scalar_max=scalar_max, size_update_period=size_update_period, @@ -1337,7 +1396,8 @@ def _test_scaled_adam(hidden_dim: int): logging.info(f"output_magnitudes = {output_magnitudes}") def _test_transform_params(): - group = { "bias_min_rms": 0.001, "weight_min_rms": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5 } + group = { "bias_min_rms": 0.001, "weight_min_rms": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5, + "weight_max_rms": 20.0, "bias_max_rms": 20.0 } for scale in [ 0.0, 1.0e-05, 0.001, 0.01, 1.0, 10.0 ]: for shape in [ (1, 1), (2, 1), (2, 2), (2, 3, 4), (3, 10, 20), (4,) ]: p = scale * torch.randn(*shape) From 19fefa502127761f3f73c32181735e77816d4324 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Mar 2025 16:42:30 +0800 Subject: [PATCH 0283/1191] Update get_parameter_groups_with_lrs(), set feedforwardX.in_proj.weight_min_rms=0.02, enforce min_rms in reverse_transform_param. --- egs/librispeech/ASR/zipformer/optim.py | 37 ++++++++-------- egs/librispeech/ASR/zipformer/train.py | 4 +- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++ icefall/utils.py | 51 +++++++++++++++------- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 76288b7854..3a01cfac27 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -275,7 +275,8 @@ def reverse_transform_param(group, p, orig_shape): is_weight = (len(orig_shape) > 2) max_rms = group["weight_max_rms"] if is_weight else group["bias_max_rms"] - scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]).exp().clamp(max=max_rms) + min_rms = group["weight_min_rms"] if is_weight else group["bias_min_rms"] + scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]).exp().clamp(min=min_rms, max=max_rms) q = p_padded[:, :-1] * scale # the :-1 is to remove the padding element. q = q.reshape(*orig_shape) @@ -397,7 +398,7 @@ def _load_state_dict_pre_hook(optim: Optimizer, state_dict: dict): except KeyError: logging.info(f"Could not copy key {key} from optim state-dict.") -class ScaledAdam(BatchedOptimizer): +class TransformedAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update proportional to the norm of that parameter; and also learn the scale of the parameter, @@ -458,9 +459,9 @@ def __init__( clipping_scale=None, beta2=0.98, betas=(.96, .9984, .99936), - scales=(4., 8., 16.), + scales=(4., 8., 10.), scalar_lr_scale=0.2, - scaling_lr_scale=0.5, + scaling_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, weight_max_rms=1.0, @@ -498,7 +499,7 @@ def __init__( # this flag will be set to False in funciton _get_names_of_parameters. self.show_dominant_parameters = True param_groups, parameters_names = self._get_names_of_parameters(params) - super(ScaledAdam, self).__init__(param_groups, defaults) + super(TransformedAdam, self).__init__(param_groups, defaults) assert len(self.param_groups) == len(parameters_names) self.parameters_names = parameters_names @@ -512,10 +513,10 @@ def _get_names_of_parameters( ) -> Tuple[List[Dict], List[List[str]]]: """ Args: - params_or_named_params: according to the way ScaledAdam is initialized in train.py, + params_or_named_params: according to the way TransformedAdam is initialized in train.py, this argument could be one of following 4 cases, case 1, a generator of parameter, e.g.: - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + optimizer = TransformedAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) case 2, a list of parameter groups with different config, e.g.: model_param_groups = [ @@ -523,10 +524,10 @@ def _get_names_of_parameters( {'params': model.decoder.parameters(), 'lr': 0.01}, {'params': model.joiner.parameters(), 'lr': 0.03}, ] - optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + optimizer = TransformedAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) case 3, a generator of named_parameter, e.g.: - optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + optimizer = TransformedAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) case 4, a list of named_parameter groups with different config, e.g.: model_named_param_groups = [ @@ -534,7 +535,7 @@ def _get_names_of_parameters( {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, ] - optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + optimizer = TransformedAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. For case 3 and case 4, firstly, names and params are extracted from input named_params, @@ -615,7 +616,7 @@ def _get_names_of_parameters( return param_groups, param_groups_names def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) + super(TransformedAdam, self).__setstate__(state) @torch.no_grad() def step(self, closure=None): @@ -653,7 +654,7 @@ def step(self, closure=None): grad = p.grad if grad.is_sparse: raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" + "TransformedAdam optimizer does not support sparse gradients" ) @@ -730,7 +731,7 @@ def _get_clipping_scale( grad = p.grad if grad.is_sparse: raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" + "TransformedAdam optimizer does not support sparse gradients" ) if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() * ( @@ -1031,7 +1032,7 @@ class Eden(LRScheduler): of an entire training run, but it doesn't matter much. You could also use Eden2 which has only the notion of batches. - We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + We suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam Args: optimizer: the optimizer to change the learning rates on @@ -1089,7 +1090,7 @@ class Eden2(LRScheduler): and then stays constant at 1. - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam Args: optimizer: the optimizer to change the learning rates on @@ -1129,7 +1130,7 @@ def get_lr(self): def _test_eden(): m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) + optim = TransformedAdam(m.parameters(), lr=0.03) scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) @@ -1152,7 +1153,7 @@ def _test_eden(): logging.info(f"state dict = {scheduler.state_dict()}") -# This is included mostly as a baseline for ScaledAdam. +# This is included mostly as a baseline for TransformedAdam. class Eve(Optimizer): """ Implements Eve algorithm. This is a modified version of AdamW with a special @@ -1342,7 +1343,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.005, clipping_scale=2.0, eps=1.0e-20) + optim = TransformedAdam(m.named_parameters(), lr=0.005, clipping_scale=2.0, eps=1.0e-20) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 2ed49b731a..2a29a2377f 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel -from optim import Eden, ScaledAdam +from optim import Eden, TransformedAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -1365,7 +1365,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( + optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, # should have no effect clipping_scale=2.0, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2e33c30b0e..55a02bec6f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1621,6 +1621,9 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed + # to the TransformedAdam optimizer. + self.in_proj.weight_min_rms = 0.02 # shared_dim=0 means we share the dropout mask along the time axis self.out_proj = ActivationDropoutAndLinear( diff --git a/icefall/utils.py b/icefall/utils.py index 41eebadd46..de4d9b2127 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1499,20 +1499,27 @@ def get_parameter_groups_with_lrs( """ named_modules = list(model.named_modules()) - # flat_lr_scale just contains the lr_scale explicitly specified - # for each prefix of the name, e.g. 'encoder.layers.3', these need - # to be multiplied for all prefix of the name of any given parameter. - flat_lr_scale = defaultdict(lambda: 1.0) + # flat_lr_scale[prefix] for a prefix like 'encoder.layers.3' contains + # a dict with all the optimizer configuration settings specified at this level. + # these need to be combined for all prefixes of the name of any given parameter. + flat_config = defaultdict(dict) names = [] for name, m in model.named_modules(): names.append(name) - if hasattr(m, "lr_scale"): - flat_lr_scale[name] = m.lr_scale + for attr in ['lr_scale', 'weight_min_rms', 'bias_min_rms', 'weight_max_rms', 'bias_max_rms']: # we can add more here as needed + try: + # getattr(m, attr) if attr == 'lr_scale' is equivalent to m.lr_scale + flat_config[name][attr] = getattr(m, attr) + except AttributeError: + pass - # lr_to_parames is a dict from learning rate (floating point) to: if + + # lr_to_parames is a dict from config-string to: # include_names == true, a list of (name, parameter) for that learning rate; # otherwise a list of parameters for that learning rate. - lr_to_params = defaultdict(list) + # The config-string is the repr(dict) for the dictionary of attributes combined + # over all prefixes of that parameter name. + config_to_params = defaultdict(list) for name, parameter in model.named_parameters(): split_name = name.split(".") @@ -1527,18 +1534,30 @@ def get_parameter_groups_with_lrs( if prefix in freeze_modules: logging.info(f"Remove {name} from parameters") continue - cur_lr = lr * flat_lr_scale[prefix] + + cur_config = dict() + cur_config.update(flat_config[prefix]) # include dict items from here. if prefix != "": - cur_lr *= flat_lr_scale[""] + cur_config.update(flat_config[""]) for part in split_name[1:]: prefix = ".".join([prefix, part]) - cur_lr *= flat_lr_scale[prefix] - lr_to_params[cur_lr].append((name, parameter) if include_names else parameter) + cur_config.update(flat_config[prefix]) - if include_names: - return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()] - else: - return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()] + + config_to_params[repr(cur_config)].append((name, parameter) if include_names else parameter) + + + ans = [ ] + for config, params in config_to_params.items(): + config = eval(config) # turn from string back into dict. + try: # turn "lr_scale" into "lr" + config["lr"] = lr * config["lr_scale"] + del config["lr_scale"] + except KeyError: + pass + config["named_params" if include_names else "params"] = params + ans.append(config) + return ans def optim_step_and_measure_param_change( From 16d0ca9c62ee5cec4c04509aa169b1a302b4d9e9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Mar 2025 16:44:46 +0800 Subject: [PATCH 0284/1191] Let attrs be user-specified in utils.py and change the name of ScaledAdam to TransformedAdam --- icefall/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index de4d9b2127..83e8106322 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1474,12 +1474,16 @@ def get_parameter_groups_with_lrs( lr: float, include_names: bool = False, freeze_modules: List[str] = [], + attrs: List[str] = ['lr_scale', 'weight_min_rms', 'bias_min_rms', 'weight_max_rms', 'bias_max_rms'], ) -> List[dict]: """ - This is for use with the ScaledAdam optimizers (more recent versions that accept lists of - named-parameters; we can, if needed, create a version without the names). + This is to automatically create parameter-groups with overrides of parameter optimizer + settings, especially the learning rate which can be scaled using the "lr_scale" attribut + in modules, but also other possible configuration values that you may specify. - It provides a way to specify learning-rate scales inside the module, so that if + + It provides a way to specify learning-rate scales and other optimizer configuration + settings inside the module, so that if any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will scale the LR of any parameters inside that module or its submodules. Note: you can set module parameters outside the __init__ function, e.g.: @@ -1506,7 +1510,7 @@ def get_parameter_groups_with_lrs( names = [] for name, m in model.named_modules(): names.append(name) - for attr in ['lr_scale', 'weight_min_rms', 'bias_min_rms', 'weight_max_rms', 'bias_max_rms']: # we can add more here as needed + for attr in attrs: # we can add more here as needed try: # getattr(m, attr) if attr == 'lr_scale' is equivalent to m.lr_scale flat_config[name][attr] = getattr(m, attr) From c2ef57517828dcdd1518f1a0e3f2b6f1dce8a603 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Mar 2025 13:07:03 +0800 Subject: [PATCH 0285/1191] Do not learn scaling factors for rows and columns. --- egs/librispeech/ASR/zipformer/optim.py | 46 ++------------------------ 1 file changed, 2 insertions(+), 44 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3a01cfac27..887f8c6af4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -201,32 +201,6 @@ def momentum_step(group, p, state, grad): return -lr * (delta + (stored_delta * alphas).sum(dim=0)) -def get_scaling_shapes(shape): - """ shape is a list representing a shape of a batch of tensors, - interpreted as (batch_size, a, b, ..). We return a list of - shapes of tensors to add to the "expanded representation" of the - tensor, that will be interpreted as (offsets to) scales on - various dimensions. - """ - num_nontrivial = sum ([ 1 if x > 1 else 0 for x in shape[1:] ]) - ans = [ ] - if num_nontrivial <= 1: - # there are no 'scaling shapes' as the tensor has less than two - # nontrivial dims. - return ans - for i in range(1, len(shape)): - if shape[i] != 1: - l = list(shape) - l[i] = 1 - ans.append(l) - return ans - -def prod_of_list(seq): - prod = 1 - for i in seq: - prod = prod * i - return prod - def forward_transform_param(group, p): """ @@ -255,13 +229,7 @@ def forward_transform_param(group, p): # log_scale controls the overall scale of this tensor log_scale = (1 / group["scaling_lr_scale"]) * scale.log() - # We also include scaling factors that will scale individual rows and columns of the - # weights. These are initially all zero, we'll scale by (1 + coeff * this_scaling_factor) - - scaling_dim = sum([ prod_of_list(l[1:]) for l in get_scaling_shapes(p.shape) ]) - scaling_factors = torch.zeros(batch_size, scaling_dim, device=p.device, dtype=p.dtype) - - ans = torch.cat((p_flat / scale, pad / scale, log_scale, scaling_factors), dim=1) + ans = torch.cat((p_flat / scale, pad / scale, log_scale), dim=1) return ans def reverse_transform_param(group, p, orig_shape): @@ -269,7 +237,7 @@ def reverse_transform_param(group, p, orig_shape): if p.numel() == batch_size: return (p * group["scalar_lr_scale"]).reshape(*orig_shape) # numel is num elements of each parameter tensor in the batch. - numel = prod_of_list(orig_shape[1:]) + numel = p.shape[1] - 2 p_padded = p[:, :numel+1] # orig tensor plus one padding element p_padded = p_padded / ((p_padded ** 2).sum(dim=1, keepdim=True) / numel).sqrt() # normalize rms to 1. @@ -284,16 +252,6 @@ def reverse_transform_param(group, p, orig_shape): # forward_transform_param. offset = numel + 2 # + 1 for the padding element and the log-scale. - S = group["scaling_lr_scale"] - shapes = get_scaling_shapes(orig_shape) - num_shapes = len(shapes) - for scaling_shape in shapes: - this_numel = prod_of_list(scaling_shape[1:]) - assert offset + this_numel <= p.shape[1] - scales = p[:, offset:offset+this_numel].reshape(*scaling_shape) - offset = offset + this_numel - scales = 1.0 + (S / num_shapes) * scales - q = q * scales return q From b3121b245eb7e50be1638b5e822f4c03ab6071dd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 26 Mar 2025 10:53:51 +0800 Subject: [PATCH 0286/1191] Implement max-beta formula dependent on LR. max_change=0.1. Debugging code enabled; must remove this. --- egs/librispeech/ASR/zipformer/optim.py | 64 ++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 887f8c6af4..0d91c4e53e 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -157,12 +157,15 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] + debug = True try: stored_delta = state["delta"] if p.numel() != p.shape[0]: alphas = state["alphas"] betas = state["betas"] + if debug: + decayed_params = state["decayed_params"] except KeyError as e: assert step < 2 if p.numel() == p.shape[0]: @@ -181,7 +184,10 @@ def momentum_step(group, p, state, grad): state["delta"] = stored_delta state["betas"] = betas state["alphas"] = alphas - + if debug: + decayed_params = torch.zeros_like(stored_delta) + decayed_params[:] = p + state["decayed_params"] = decayed_params if p.numel() == p.shape[0]: # scalar. use conventional momentum. @@ -190,11 +196,61 @@ def momentum_step(group, p, state, grad): lr = lr * group["scalar_lr_scale"] return -lr * stored_delta + max_change = 0.1 + if max_change > 0.0: + # Limit the beta values in an LR-dependent way, that doesn't allow us to + # hang onto "momentum" past the point when we expect the parameter to have + # significantly changed by more than "max_change" of relative change. + # the formulas are a bit approximate but we can just tune max_change to + # compensate. OK, so we assume that the update formula will be as follows: + # stored_delta *= betas + # stored_delta += delta + # return -lr * (delta + (stored_delta * alphas).sum(dim=0)) + # and we limit the betas so that, considering only the part of the change + # from "this beta value" (i.e. from this accumulator, ignoring the other betas and + # the implicit "beta=0,alpha=1") we don't keep changes from gradients that + # were accumulated when the parameter was too significantly different. + # For this we assume gradient independence between steps, i.e. independence between the + # delta values returned by basic_step. So the "effective LR" for each beta-accumulator, + # if we count the whole decay sequence, will be: + # lr_eff = lr * alpha * 1/(1-beta) + # .. and the "extra variance" that this adds to the parameter value is lr_eff^2. + # The parameter values have been normalized by forward_transform_param so that + # their variance is about 1, so this "extra variance" can be interpreted as a + # "relative change in variance", and the corresponding "relative change in parameter", + # interpreted in a cosine-of-angle sense + # would be sqrt(relative-change-in-variance) ~ 0.5 (relative-change-in-variance) + # [we shrink the parameter down by 1/sqrt(relative-change-in-variance) to renormalize + # it, and we can assume the gradient + # was orthogonal to the parameter on average.] + # Now, the time period over which we have to measure parameter changes is about + # 1/(1-beta); this is the "decay time" of the accumulator. So the limiting-to-max-change + # equation becomes: + # 1/(1-beta). 0.5 lr_eff^2 <= max_change + # and expanding: + # 1/(1-beta). 0.5 (lr * alpha * 1/(1-beta))^2 <= max_change + # (lr * alpha)^2 <= max_change (1-beta)^3 + # (1-beta) >= ((0.5/max_change) * (lr * alpha)**2) ** 1/3. + # beta <= 1 - ((0.5/max_change) * (lr * alpha)**2) ** 1/3. + # so: + ceil = 1. - (((0.5 / max_change) * lr**2) * (alphas**2)) ** 0.3333 + if random.random() < 0.002 and not debug: + logging.info(f"lr={lr}, shape={list(p.shape)}, ceil={ceil.flatten()}, betas={betas.flatten()}") + betas = torch.minimum(betas, ceil) + + + if debug: + decayed_params *= betas + decayed_params += (1-betas) * p + if random.random() < 0.001: + dims = tuple(range(1, decayed_params.ndim)) + cosine = ((p * decayed_params).sum(dim=dims) / + ((p*p).sum() * (decayed_params*decayed_params).sum(dim=dims)).sqrt()) + param_change = 1 - cosine + logging.info(f"lr={lr}, shape={list(p.shape)}, ceil={ceil.flatten()}, betas={betas.flatten()}, max_change={max_change}, change={param_change}") + - # an extra decay of the deltas near the beginning of training, as early grads may change fast. - decay = 1. - 1.5 / (15 + step) - stored_delta.mul_(decay) stored_delta *= betas stored_delta += delta From 86727f876186b51d2c922f9ac2fda4300d62305f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 26 Mar 2025 14:29:04 +0800 Subject: [PATCH 0287/1191] Do not take into account log-scale in computing param distance for loggin --- egs/librispeech/ASR/zipformer/optim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 0d91c4e53e..57bccdd2f1 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -244,6 +244,8 @@ def momentum_step(group, p, state, grad): decayed_params += (1-betas) * p if random.random() < 0.001: dims = tuple(range(1, decayed_params.ndim)) + p = p[:, :-2] # get rid of the padding element and the log-scale parameter (which may be too large) + decayed_params = decayed_params[:, :, :-2] cosine = ((p * decayed_params).sum(dim=dims) / ((p*p).sum() * (decayed_params*decayed_params).sum(dim=dims)).sqrt()) param_change = 1 - cosine From 637fe9a5bff711910179678c4451f95364569f04 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 26 Mar 2025 16:39:52 +0800 Subject: [PATCH 0288/1191] Change betas from (.96, .9984, .99936) to betas=(.95, .9975, .999875) and scales from 4,8,10 to 4,8,16. --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 57bccdd2f1..5ecf48945e 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -474,8 +474,8 @@ def __init__( lr=3e-02, clipping_scale=None, beta2=0.98, - betas=(.96, .9984, .99936), - scales=(4., 8., 10.), + betas=(.95, .9975, .999875), + scales=(4., 8., 16.), scalar_lr_scale=0.2, scaling_lr_scale=0.1, eps=1.0e-08, From b2cf843f368160105dc4bcd3d6c568f73f094863 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 26 Mar 2025 16:57:30 +0800 Subject: [PATCH 0289/1191] Set debug to False in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 5ecf48945e..bdd05c4d70 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -157,7 +157,7 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] - debug = True + debug = False try: stored_delta = state["delta"] From f47aba66485935dfac1cdca24f3569a9b738fd9d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 27 Mar 2025 13:34:23 +0800 Subject: [PATCH 0290/1191] Reduce rand_floor in Expnorm from .25 to 0. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a2fc49e65b..9249bd174b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -464,7 +464,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.25, + rand_floor: FloatLike = 0.0, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From 7dd5f87f092b413e031427d28bfc930789efd66c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 27 Mar 2025 15:52:50 +0800 Subject: [PATCH 0291/1191] Cosmetic improvements in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index bdd05c4d70..9166b9ff59 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -483,13 +483,12 @@ def __init__( weight_max_rms=1.0, bias_min_rms=1.0e-05, bias_max_rms=5.0, - decay_scale=0.5, - scalar_max=10.0, size_update_period=4, clipping_update_period=100, debug_interval=0, ): + assert len(betas) == len(scales) defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -503,8 +502,6 @@ def __init__( bias_max_rms=bias_max_rms, bias_min_rms=bias_min_rms, weight_max_rms=weight_max_rms, - decay_scale=decay_scale, - scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, debug_interval=debug_interval, @@ -683,10 +680,6 @@ def step(self, closure=None): grad = (p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)) p[:] = debug_step(group, p.detach(), state, grad) - if p.numel() == p.shape[0]: # scalar parameter - scalar_max = group["scalar_max"] - p.clamp_(min=-scalar_max, max=scalar_max) - state["step"] = cur_step + 1 From 76fb658e70da8830c9d58c14656cc30a798ba62d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 28 Mar 2025 15:59:12 +0800 Subject: [PATCH 0292/1191] Restore rand_floor in ExpNorm from 0.0 to 0.25, reversing 430->432. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9249bd174b..a2fc49e65b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -464,7 +464,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.0, + rand_floor: FloatLike = 0.25, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From 8cc82b313c1e301b4f8f4fdcd72a6b0dd044ab09 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 3 Apr 2025 13:08:50 +0800 Subject: [PATCH 0293/1191] Revert to conventional momentum, keep transformed_params --- egs/librispeech/ASR/zipformer/optim.py | 113 +++---------------------- 1 file changed, 11 insertions(+), 102 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9166b9ff59..51aadc2e40 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -156,107 +156,19 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] - - debug = False + beta1 = group["beta1"] try: stored_delta = state["delta"] - if p.numel() != p.shape[0]: - alphas = state["alphas"] - betas = state["betas"] - if debug: - decayed_params = state["decayed_params"] except KeyError as e: assert step < 2 - if p.numel() == p.shape[0]: - # scalar. use conventional momentum. - stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - state["delta"] = stored_delta - else: - # e.g. group["betas"] = (.96, .9984, .99936), - # group["scales"] = (4., 8., 16.), - betas = torch.tensor(group["betas"]).to(device=p.device) - alphas = torch.tensor(group["scales"]).to(device=p.device) * (1-betas) - for _ in range(p.ndim): - betas, alphas = betas.unsqueeze(-1), alphas.unsqueeze(-1) - - stored_delta = torch.zeros(len(betas), *p.shape, device=p.device, dtype=torch.float) - state["delta"] = stored_delta - state["betas"] = betas - state["alphas"] = alphas - if debug: - decayed_params = torch.zeros_like(stored_delta) - decayed_params[:] = p - state["decayed_params"] = decayed_params - - if p.numel() == p.shape[0]: # scalar. use conventional momentum. - beta = 0.9 - stored_delta.mul_(beta).add_(delta, alpha=(1-beta)) - lr = lr * group["scalar_lr_scale"] - return -lr * stored_delta - - max_change = 0.1 - if max_change > 0.0: - # Limit the beta values in an LR-dependent way, that doesn't allow us to - # hang onto "momentum" past the point when we expect the parameter to have - # significantly changed by more than "max_change" of relative change. - # the formulas are a bit approximate but we can just tune max_change to - # compensate. OK, so we assume that the update formula will be as follows: - # stored_delta *= betas - # stored_delta += delta - # return -lr * (delta + (stored_delta * alphas).sum(dim=0)) - # and we limit the betas so that, considering only the part of the change - # from "this beta value" (i.e. from this accumulator, ignoring the other betas and - # the implicit "beta=0,alpha=1") we don't keep changes from gradients that - # were accumulated when the parameter was too significantly different. - # For this we assume gradient independence between steps, i.e. independence between the - # delta values returned by basic_step. So the "effective LR" for each beta-accumulator, - # if we count the whole decay sequence, will be: - # lr_eff = lr * alpha * 1/(1-beta) - # .. and the "extra variance" that this adds to the parameter value is lr_eff^2. - # The parameter values have been normalized by forward_transform_param so that - # their variance is about 1, so this "extra variance" can be interpreted as a - # "relative change in variance", and the corresponding "relative change in parameter", - # interpreted in a cosine-of-angle sense - # would be sqrt(relative-change-in-variance) ~ 0.5 (relative-change-in-variance) - # [we shrink the parameter down by 1/sqrt(relative-change-in-variance) to renormalize - # it, and we can assume the gradient - # was orthogonal to the parameter on average.] - # Now, the time period over which we have to measure parameter changes is about - # 1/(1-beta); this is the "decay time" of the accumulator. So the limiting-to-max-change - # equation becomes: - # 1/(1-beta). 0.5 lr_eff^2 <= max_change - # and expanding: - # 1/(1-beta). 0.5 (lr * alpha * 1/(1-beta))^2 <= max_change - # (lr * alpha)^2 <= max_change (1-beta)^3 - # (1-beta) >= ((0.5/max_change) * (lr * alpha)**2) ** 1/3. - # beta <= 1 - ((0.5/max_change) * (lr * alpha)**2) ** 1/3. - # so: - ceil = 1. - (((0.5 / max_change) * lr**2) * (alphas**2)) ** 0.3333 - if random.random() < 0.002 and not debug: - logging.info(f"lr={lr}, shape={list(p.shape)}, ceil={ceil.flatten()}, betas={betas.flatten()}") - betas = torch.minimum(betas, ceil) - - - if debug: - decayed_params *= betas - decayed_params += (1-betas) * p - if random.random() < 0.001: - dims = tuple(range(1, decayed_params.ndim)) - p = p[:, :-2] # get rid of the padding element and the log-scale parameter (which may be too large) - decayed_params = decayed_params[:, :, :-2] - cosine = ((p * decayed_params).sum(dim=dims) / - ((p*p).sum() * (decayed_params*decayed_params).sum(dim=dims)).sqrt()) - param_change = 1 - cosine - logging.info(f"lr={lr}, shape={list(p.shape)}, ceil={ceil.flatten()}, betas={betas.flatten()}, max_change={max_change}, change={param_change}") - - - - stored_delta *= betas - stored_delta += delta - - return -lr * (delta + (stored_delta * alphas).sum(dim=0)) + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + + + stored_delta.mul_(beta1).add_(delta, alpha=(1-beta1)) + return -lr * stored_delta @@ -473,9 +385,8 @@ def __init__( params, lr=3e-02, clipping_scale=None, + beta1=0.9, beta2=0.98, - betas=(.95, .9975, .999875), - scales=(4., 8., 16.), scalar_lr_scale=0.2, scaling_lr_scale=0.1, eps=1.0e-08, @@ -488,13 +399,11 @@ def __init__( debug_interval=0, ): - assert len(betas) == len(scales) defaults = dict( lr=lr, clipping_scale=clipping_scale, + beta1=beta1, beta2=beta2, - betas=betas, - scales=scales, scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, @@ -1352,7 +1261,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = TransformedAdam(m.named_parameters(), lr=0.005, clipping_scale=2.0, eps=1.0e-20) + optim = TransformedAdam(m.named_parameters(), lr=0.075, clipping_scale=2.0, eps=1.0e-20) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() @@ -1413,7 +1322,7 @@ def _test_transform_params(): p = scale * torch.randn(*shape) q = forward_transform_param(group, p) r = reverse_transform_param(group, q, p.shape) - assert torch.allclose(p, r), (p, q, r) + assert torch.allclose(p, r, atol=1.0e-03), (p, q, r) if __name__ == "__main__": From 83ad9c3d5ea500d0cf74f897445406c48ebe9a0a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Apr 2025 12:21:54 +0800 Subject: [PATCH 0294/1191] Increase beta1 to 0.98, add direct=0.05, decrease scalar_lr_scale from .2 to .1 --- egs/librispeech/ASR/zipformer/optim.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 51aadc2e40..4a866ca834 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -156,7 +156,8 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] - beta1 = group["beta1"] + beta1 = min(group["beta1"], 1. - 1. / (10. + step)) + direct = group["direct"] try: stored_delta = state["delta"] @@ -167,8 +168,8 @@ def momentum_step(group, p, state, grad): state["delta"] = stored_delta - stored_delta.mul_(beta1).add_(delta, alpha=(1-beta1)) - return -lr * stored_delta + stored_delta.mul_(beta1).add_(delta, alpha=(1-beta1) * (1-direct)) + return -lr * (stored_delta + direct * delta) @@ -385,9 +386,10 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.9, + beta1=0.98, + direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, - scalar_lr_scale=0.2, + scalar_lr_scale=0.1, scaling_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, @@ -403,6 +405,7 @@ def __init__( lr=lr, clipping_scale=clipping_scale, beta1=beta1, + direct=direct, beta2=beta2, scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, @@ -1261,7 +1264,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = TransformedAdam(m.named_parameters(), lr=0.075, clipping_scale=2.0, eps=1.0e-20) + optim = TransformedAdam(m.named_parameters(), lr=0.06, clipping_scale=2.0, eps=1.0e-20) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() @@ -1322,7 +1325,7 @@ def _test_transform_params(): p = scale * torch.randn(*shape) q = forward_transform_param(group, p) r = reverse_transform_param(group, q, p.shape) - assert torch.allclose(p, r, atol=1.0e-03), (p, q, r) + assert torch.allclose(p, r, atol=1.0e-02), (p, q, r) if __name__ == "__main__": From 39b6e146454cc4dbf3f98596e28939dec1eb0d12 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Apr 2025 20:03:15 +0800 Subject: [PATCH 0295/1191] Add EdenBounce, cosine-like lr-schedule modifier. Default: bounce_radius 500, bounce_perioe 4000, bounce_bottom: 0.333 --- egs/librispeech/ASR/zipformer/optim.py | 93 ++++++++++++++++++++++++++ egs/librispeech/ASR/zipformer/train.py | 15 +---- 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 4a866ca834..b58ace5a04 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1049,6 +1049,71 @@ def get_lr(self): return [x * factor * warmup_factor for x in self.base_lrs] + +class EdenBounce(LRScheduler): + """ + Somewhat the Eden scheduler, but simpler than Eden because it does not use the notion of epoch, + only batches. + + The "bounce" refers to a cosine-schedule-like feature, which we implement as a kind of + extension of the notion of the "warmup period". This means that if we are within + "bounce_period" batches of one of the "bounce times", we slow down the learning rate + by multiplying in a sawtooth "bounce factor" which is 1.0 if we are far from any + bounce period and goes down to a value of "bounce_bottom", e.g. 0.25, if we are at exactly + at the "bounce tim". + + The basic formula (before bounce-factor) is: + lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + bounce_radius: the distance from the center to the edge of each sawtooth + cutout. + bounce_period: the distance from one sawtooth to the next. + bounce_bottom: a factor betwedn 0 and 1, which defines the bottom of each + sawtooth in the bounce-factor. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + bounce_radius: Union[int, float] = 500.0, + bounce_period: float = 4000.0, + bounce_bottom: float = 0.333, + verbose: bool = False, + ): + super().__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.bounce_radius = bounce_radius + self.bounce_period = bounce_period + self.bounce_bottom = bounce_bottom + self.verbose = verbose + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.5 + + closest_bounce_center = self.bounce_period * int(0.5 +self.batch / self.bounce_period) + bounce_distance = abs(self.batch - closest_bounce_center) + + bounce_factor = ( + 1.0 if bounce_distance > self.bounce_radius + else self.bounce_bottom + + (1.0 - self.bounce_bottom) * (bounce_distance / self.bounce_radius) + ) + + return [x * factor * bounce_factor for x in self.base_lrs] + + + + + def _test_eden(): m = torch.nn.Linear(100, 100) optim = TransformedAdam(m.parameters(), lr=0.03) @@ -1073,6 +1138,33 @@ def _test_eden(): logging.info(f"last lr = {scheduler.get_last_lr()}") logging.info(f"state dict = {scheduler.state_dict()}") +def _test_eden_bounce(): + m = torch.nn.Linear(100, 100) + optim = TransformedAdam(m.parameters(), lr=0.01) + + scheduler = EdenBounce(optim, lr_batches=10000, + verbose=True, + bounce_period=100, + bounce_radius=10) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + # This is included mostly as a baseline for TransformedAdam. class Eve(Optimizer): @@ -1345,6 +1437,7 @@ def _test_transform_params(): else: hidden_dim = 200 + _test_eden_bounce() _test_transform_params() _test_scaled_adam(hidden_dim) _test_eden() diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 2a29a2377f..22d82334a1 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel -from optim import Eden, TransformedAdam +from optim import EdenBounce, TransformedAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -400,19 +400,11 @@ def get_parser(): parser.add_argument( "--lr-batches", type=float, - default=7500, + default=10000, help="""Number of steps that affects how rapidly the learning rate decreases. We suggest not to change this.""", ) - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - parser.add_argument( "--ref-duration", type=float, @@ -1372,7 +1364,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + scheduler = EdenBounce(optimizer, params.lr_batches) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1481,7 +1473,6 @@ def remove_short_and_long_utt(c: Cut): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) From b6efd00df2fcb3fe9db9921f53050fc0e5eda04a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Apr 2025 23:34:10 +0800 Subject: [PATCH 0296/1191] Apply the 1-beta1 scale in two halves. --- egs/librispeech/ASR/zipformer/optim.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b58ace5a04..9b600e5405 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -168,8 +168,9 @@ def momentum_step(group, p, state, grad): state["delta"] = stored_delta - stored_delta.mul_(beta1).add_(delta, alpha=(1-beta1) * (1-direct)) - return -lr * (stored_delta + direct * delta) + sqrt_scale=(1-beta1) ** 0.5 + stored_delta.mul_(beta1).add_(delta, alpha=sqrt_scale) + return ((-lr * (1-direct) * sqrt_scale) * stored_delta) + ((-lr * direct) * delta) From 0f9353ad3be7ec11a89e8ef174088924cbdf20ad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Apr 2025 23:38:55 +0800 Subject: [PATCH 0297/1191] Reduce scalar_lr_scale to 0.05. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9b600e5405..74b8f4bab0 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -390,7 +390,7 @@ def __init__( beta1=0.98, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, - scalar_lr_scale=0.1, + scalar_lr_scale=0.05, scaling_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, From 0e31c234593514532e0716fcf7340b10a9386dc5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Apr 2025 23:46:52 +0800 Subject: [PATCH 0298/1191] revert scalar_lr_scale change; add factor of 0.5 in formula for beta1 --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 74b8f4bab0..d7dd2b61d4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -156,7 +156,7 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] - beta1 = min(group["beta1"], 1. - 1. / (10. + step)) + beta1 = min(group["beta1"], 1. - 1. / (10. + 0.5 * step)) direct = group["direct"] try: @@ -390,7 +390,7 @@ def __init__( beta1=0.98, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, - scalar_lr_scale=0.05, + scalar_lr_scale=0.1, scaling_lr_scale=0.1, eps=1.0e-08, weight_min_rms=0.005, From 9abfaadd120ec9a4ba321e12209dc2ed9d68d093 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 7 Apr 2025 00:04:29 +0800 Subject: [PATCH 0299/1191] Put whole 1-beta1 scale after averaging --- egs/librispeech/ASR/zipformer/optim.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index d7dd2b61d4..058c69c187 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -167,10 +167,8 @@ def momentum_step(group, p, state, grad): stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) state["delta"] = stored_delta - - sqrt_scale=(1-beta1) ** 0.5 - stored_delta.mul_(beta1).add_(delta, alpha=sqrt_scale) - return ((-lr * (1-direct) * sqrt_scale) * stored_delta) + ((-lr * direct) * delta) + stored_delta.mul_(beta1).add_(delta) + return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) From 0faafe823f7fc4a79217c0a6c8f547143401d32f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 8 Apr 2025 17:22:39 +0800 Subject: [PATCH 0300/1191] Apply normalize module to bypass in Zipformer2Encoder --- egs/librispeech/ASR/zipformer/train.py | 9 ++++++-- egs/librispeech/ASR/zipformer/zipformer.py | 24 ++++++++++++++++------ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 22d82334a1..e953171898 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -644,17 +644,22 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: # by a factor of 2, and most of the encoder stacks will run at a lower # sampling rate. output_downsampling_factor = 2 + encoder_embed_dim = max(_to_int_tuple(params.encoder_dim)) // output_downsampling_factor encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, - out_channels=max(_to_int_tuple(params.encoder_dim)) // output_downsampling_factor, + out_channels=encoder_embed_dim, dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), ) return encoder_embed def get_encoder_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 + # the formula below is just for historical reasons, could be anything >= min(params.encoder_dim). + encoder_embed_dim = max(_to_int_tuple(params.encoder_dim)) // output_downsampling_factor encoder = Zipformer2( - output_downsampling_factor=2, + input_dim=encoder_embed_dim, + output_downsampling_factor=output_downsampling_factor, downsampling_factor=_to_int_tuple(params.downsampling_factor), num_encoder_layers=_to_int_tuple(params.num_encoder_layers), encoder_dim=_to_int_tuple(params.encoder_dim), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 55a02bec6f..696e795ba8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -92,6 +92,7 @@ class Zipformer2(EncoderInterface): """ def __init__( self, + input_dim: int, output_downsampling_factor: int = 2, downsampling_factor: Tuple[int] = (2, 4), encoder_dim: Union[int, Tuple[int]] = 384, @@ -147,7 +148,6 @@ def _to_tuple(x): num_encoders = len(downsampling_factor) cur_downsample = 1 - input_dim = max(encoder_dim) // output_downsampling_factor # caution: some changes we made for this break the streaming, later we'll try to fix this. encoders_downsampling_factors = [ ] @@ -188,8 +188,8 @@ def set_downsample_factor(cur_downsample, ds): encoder = Zipformer2Encoder( encoder_layer, num_encoder_layers[i], + dim=cur_downsample*input_dim, pos_dim=pos_dim, - dropout=dropout, ) encoder.encoder_index = i # <-- will be used in streaming_forward encoders.append(encoder) @@ -705,21 +705,25 @@ class Zipformer2Encoder(nn.Module): Args: encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). + dim: the dimension of the input and output (layer dim may be less than this). pos_dim: the dimension for the relative positional encoding +dropout: Examples:: >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) + + """ def __init__( self, encoder_layer: nn.Module, num_layers: int, + dim: int, pos_dim: int, - dropout: float, ) -> None: super().__init__() self.encoder_pos = CompactRelPositionalEncoding( @@ -730,7 +734,11 @@ def __init__( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers - self.copy_bypass = nn.Identity() # in case we are dumping diagnostics. + + bypass_dim = dim - encoder_layer.embed_dim + assert bypass_dim >= 0 + if bypass_dim > 0: + self.norm_bypass = ExpNorm(bypass_dim) self.whiten = Whiten( num_groups=1, @@ -784,7 +792,9 @@ def forward( src = self.whiten(src) if num_channels > layer_dim: - bypass = self.copy_bypass(bypass) + # we pass the bypass through the norm layer mainly to prevent the model from having an incentive + # to pass the more informative feature dimensions through the bypass. + bypass = self.norm_bypass(bypass) src = torch.cat((src, bypass), dim=-1) return src @@ -1980,14 +1990,16 @@ def _test_zipformer_main(causal: bool = False): seq_len = 20 # Just make sure the forward pass runs. + input_dim = 50 + c = Zipformer2( + input_dim=input_dim, encoder_dim=(64, 96), num_heads=(4, 4), causal=causal, chunk_size=(4,) if causal else (-1,), left_context_frames=(64,), ) - input_dim = 96 // 2 # this makes little sense, it relates to how the code used to work. batch_size = 5 seq_len = 21 From c5dc07c482ca04dedfb95871462a112fadb12500 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 8 Apr 2025 19:28:39 +0800 Subject: [PATCH 0301/1191] Add command line option for --embed-dim --- egs/librispeech/ASR/zipformer/train.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index e953171898..9db20259d7 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -142,6 +142,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Downsampling factor for each stack of encoder layers.", ) + parser.add_argument( + "--embed-dim", + type=int, + default=192, + help="Output dimension of frontend, also determines bypass dimensions in zipformer layers.", + ) + parser.add_argument( "--feedforward-dim", type=str, @@ -643,23 +650,18 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: # In the normal configuration, we will downsample once more at the end # by a factor of 2, and most of the encoder stacks will run at a lower # sampling rate. - output_downsampling_factor = 2 - encoder_embed_dim = max(_to_int_tuple(params.encoder_dim)) // output_downsampling_factor encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, - out_channels=encoder_embed_dim, + out_channels=params.embed_dim, dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), ) return encoder_embed def get_encoder_model(params: AttributeDict) -> nn.Module: - output_downsampling_factor = 2 - # the formula below is just for historical reasons, could be anything >= min(params.encoder_dim). - encoder_embed_dim = max(_to_int_tuple(params.encoder_dim)) // output_downsampling_factor encoder = Zipformer2( - input_dim=encoder_embed_dim, - output_downsampling_factor=output_downsampling_factor, + input_dim=params.embed_dim, + output_downsampling_factor=2, downsampling_factor=_to_int_tuple(params.downsampling_factor), num_encoder_layers=_to_int_tuple(params.num_encoder_layers), encoder_dim=_to_int_tuple(params.encoder_dim), From 3beac5f1c5d0e1a34109216312ba6b5e3df5e798 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 8 Apr 2025 19:55:40 +0800 Subject: [PATCH 0302/1191] Bug fix regarding encoder_dim --- egs/librispeech/ASR/zipformer/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9db20259d7..f3f0c5b089 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -692,8 +692,9 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), + encoder_dim=params.embed_dim * output_downsampling_factor, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -709,7 +710,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module: attention_dim=params.attention_decoder_attention_dim, num_heads=params.attention_decoder_num_heads, feedforward_dim=params.attention_decoder_feedforward_dim, - memory_dim=max(_to_int_tuple(params.encoder_dim)), + memory_dim=params.embed_dim * output_downsampling_factor, sos_id=params.sos_id, eos_id=params.eos_id, ignore_id=params.ignore_id, @@ -740,13 +741,14 @@ def get_model(params: AttributeDict) -> nn.Module: else: attention_decoder = None + output_downsampling_factor = 2 model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, attention_decoder=attention_decoder, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), + encoder_dim=output_downsampling_factor * params.embed_dim, decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, use_transducer=params.use_transducer, From 22dcfc673f58cb182d94ebb253ed3a2d5565d3b7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Apr 2025 12:09:26 +0800 Subject: [PATCH 0303/1191] Reduce rand_floor from 0.25 to 0.0. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a2fc49e65b..9249bd174b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -464,7 +464,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.25, + rand_floor: FloatLike = 0.0, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From cc259297d49b384845273376b6638dbb5c25b0f0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Apr 2025 14:10:35 +0800 Subject: [PATCH 0304/1191] remove ScaleLimiter. --- egs/librispeech/ASR/zipformer/subsampling.py | 4 ---- egs/librispeech/ASR/zipformer/zipformer.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 7763e16a56..f47befd325 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -221,8 +221,6 @@ def __init__( self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, initial_scale=4.0) - self.out_limiter = ScaleLimiter(max_scale=4.0) - # use a larger than normal grad_scale on this whitening module; there is # only one such module, so there is not a concern about adding together # many copies of this extra gradient term. @@ -265,7 +263,6 @@ def forward( # now x: (N, (T-7)//2, out_width * layer3_channels)) x = self.out(x) - x = self.out_limiter(x) # Now x is of shape (N, (T-7)//2, odim) x = self.out_whiten(x) x = self.out_norm(x) @@ -319,7 +316,6 @@ def streaming_forward( x = self.out(x) # Now x is of shape (N, T', odim) - x = self.out_limiter(x) x = self.out_norm(x) if torch.jit.is_scripting() or torch.jit.is_tracing(): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 696e795ba8..e6ab159554 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -520,8 +520,6 @@ def __init__( self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) for _ in range(2) ] - self.scale_limiter = ScaleLimiter(max_scale=2.0) - self.norm = ExpNorm(embed_dim) @@ -574,8 +572,6 @@ def forward( src = self.bypass(src_orig, src) - src = self.scale_limiter(src) - return self.norm(src) def streaming_forward( From b0f982f4446b502ebf32679abad3fe5f65188aaf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Apr 2025 14:55:45 +0800 Subject: [PATCH 0305/1191] Introduce PredictLoss, with loss scale 0.01. --- egs/librispeech/ASR/zipformer/scaling.py | 77 +++++++++++++++++++++- egs/librispeech/ASR/zipformer/zipformer.py | 11 +++- 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9249bd174b..437f2c1c01 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -565,6 +565,82 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans +class PredictFunction(torch.autograd.Function): + # cross-prediction thing to be used in conjunction with CR-CTC or anything else + # where the batch is repeated twice but with different spec-augment. + # assume channel dim is dim -1, batch_dim is specified by user. + @staticmethod + @custom_fwd + def forward(ctx, x, pred_weight, proj_weight, loss_scale, batch_dim, name): + ctx.save_for_backward(x, pred_weight, proj_weight) + ctx.loss_scale = loss_scale # loss_scale is relative to existing grad. + ctx.name = name + ctx.batch_dim = batch_dim + return x + + @staticmethod + @custom_bwd + def backward(ctx, x_grad): + x, pred_weight, proj_weight = ctx.saved_tensors + + batch_size = x.shape[ctx.batch_dim] + assert batch_size % 2 == 0, "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." + + + # is_top1 true if this is the highest dot-product with x_proj. + x_proj = torch.matmul(x, proj_weight.t()) + top1_indexes = torch.max(x_proj, dim=-1, keepdim=True)[1] + top1_indexes = torch.roll(top1_indexes, batch_size // 2, ctx.batch_dim) + + # take loss_scale from other copy of this utterance/item. + loss_scale = ctx.loss_scale * x_grad.norm(dim=-1, keepdim=True) + loss_scale = torch.roll(loss_scale, batch_size // 2, ctx.batch_dim) + + x = x.detach() + pred_weight = pred_weight.detach() + x.requires_grad = True + pred_weight.requires_grad = True + + with torch.enable_grad(): + x_pred = torch.matmul(x, pred_weight.t()) + logprobs = x_pred.log_softmax(dim=-1) + loss = -torch.gather(logprobs, dim=-1, index=top1_indexes) + loss_scaled = loss * loss_scale + loss_scaled.backward(gradient=torch.ones_like(loss_scaled)) + if random.random() < 0.002: + logging.info(f"name={ctx.name}, mean loss before scale = {loss.mean()}") + + return x_grad + x.grad, pred_weight.grad, None, None, None, None + +class PredictLoss(nn.Module): + """ + Adds an auxiliary loss based on predicting the top-1 of inner-products with + a random set of vectors, of the "other copy" of each utterance (it assumes + you are doing something like CR-CTC so + """ + def __init__(self, + num_channels: int, + num_centers: int, + loss_scale: FloatLike = 0.01, + batch_dim: int = 0): + super().__init__() + self.register_buffer('proj_weight', + torch.randn(num_centers, num_channels), + persistent=True) + self.pred_weight = nn.Parameter(torch.zeros(num_centers, num_channels)) + self.loss_scale = loss_scale + self.batch_dim = batch_dim + self.name = None # will be set from training code + + + def forward(self, + x: Tensor) -> Tensor: + return PredictFunction.apply(x, self.pred_weight, self.proj_weight, + self.loss_scale, self.batch_dim, + self.name) + + + class OrthogonalLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd @@ -1755,7 +1831,6 @@ class ActivationDropoutAndLinear(torch.nn.Module): efficient if there are modules before this one that cache the input for their backprop (e.g. Balancer or Whiten). """ - def __init__( self, in_channels: int, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e6ab159554..322c53b048 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -40,6 +40,7 @@ convert_num_channels, limit_param_value, penalize_abs_values_gt, + PredictLoss, softmax, ) from torch import Tensor, nn @@ -522,6 +523,8 @@ def __init__( self.norm = ExpNorm(embed_dim) + self.predict_loss = PredictLoss(embed_dim, 256, loss_scale=0.01, batch_dim=1) + def forward( self, @@ -572,7 +575,11 @@ def forward( src = self.bypass(src_orig, src) - return self.norm(src) + src = self.norm(src) + + src = self.predict_loss(src) + + return src def streaming_forward( self, @@ -1997,7 +2004,7 @@ def _test_zipformer_main(causal: bool = False): left_context_frames=(64,), ) - batch_size = 5 + batch_size = 6 # make it even, as PredictLoss requires even batch size. seq_len = 21 # Just make sure the forward pass runs. f = c( From 5d0d670b16ede2edf04e0783240640d1ca9a8b22 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Apr 2025 16:09:37 +0800 Subject: [PATCH 0306/1191] Take into account x norm in loss scale --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 437f2c1c01..8c6fd89e10 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -593,7 +593,7 @@ def backward(ctx, x_grad): top1_indexes = torch.roll(top1_indexes, batch_size // 2, ctx.batch_dim) # take loss_scale from other copy of this utterance/item. - loss_scale = ctx.loss_scale * x_grad.norm(dim=-1, keepdim=True) + loss_scale = ctx.loss_scale * x_grad.norm(dim=-1, keepdim=True) * x.norm(dim=-1, keepdim=True) loss_scale = torch.roll(loss_scale, batch_size // 2, ctx.batch_dim) x = x.detach() From a8fdac12cb259db4ae72c3c0682026136d45ebfd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Apr 2025 16:43:53 +0800 Subject: [PATCH 0307/1191] Bug fix to reduce nans --- egs/librispeech/ASR/zipformer/scaling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8c6fd89e10..2b20858a36 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -593,8 +593,9 @@ def backward(ctx, x_grad): top1_indexes = torch.roll(top1_indexes, batch_size // 2, ctx.batch_dim) # take loss_scale from other copy of this utterance/item. - loss_scale = ctx.loss_scale * x_grad.norm(dim=-1, keepdim=True) * x.norm(dim=-1, keepdim=True) - loss_scale = torch.roll(loss_scale, batch_size // 2, ctx.batch_dim) + with torch.cuda.amp.autocast(enabled=False): + loss_scale = ctx.loss_scale * x_grad.to(torch.float).norm(dim=-1, keepdim=True) * x.to(torch.float).norm(dim=-1, keepdim=True) + loss_scale = torch.roll(loss_scale, batch_size // 2, ctx.batch_dim) x = x.detach() pred_weight = pred_weight.detach() From ee7498683914f3f98793cb1d4f8de5019e0c52f9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Apr 2025 10:23:38 +0800 Subject: [PATCH 0308/1191] Add back scale_limiter --- egs/librispeech/ASR/zipformer/scaling.py | 16 ++++++++++------ egs/librispeech/ASR/zipformer/zipformer.py | 4 ++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2b20858a36..c9931939fd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -596,6 +596,7 @@ def backward(ctx, x_grad): with torch.cuda.amp.autocast(enabled=False): loss_scale = ctx.loss_scale * x_grad.to(torch.float).norm(dim=-1, keepdim=True) * x.to(torch.float).norm(dim=-1, keepdim=True) loss_scale = torch.roll(loss_scale, batch_size // 2, ctx.batch_dim) + loss_scale = loss_scale.clamp(max=0.1) # TEMP x = x.detach() pred_weight = pred_weight.detach() @@ -611,7 +612,9 @@ def backward(ctx, x_grad): if random.random() < 0.002: logging.info(f"name={ctx.name}, mean loss before scale = {loss.mean()}") - return x_grad + x.grad, pred_weight.grad, None, None, None, None + + extra_grad = torch.nan_to_num(x.grad, nan=0.0).clamp(min=-0.1, max=0.1) + return x_grad + extra_grad, pred_weight.grad, None, None, None, None class PredictLoss(nn.Module): """ @@ -622,14 +625,15 @@ class PredictLoss(nn.Module): def __init__(self, num_channels: int, num_centers: int, - loss_scale: FloatLike = 0.01, + loss_scale: FloatLike = ScheduledFloat((0.0, 1.0e-06), (1000.0, 1.0e-05), (2000.0, 0.01)), batch_dim: int = 0): super().__init__() + scale = num_channels ** -0.5 self.register_buffer('proj_weight', - torch.randn(num_centers, num_channels), + scale * torch.randn(num_centers, num_channels), persistent=True) - self.pred_weight = nn.Parameter(torch.zeros(num_centers, num_channels)) - self.loss_scale = loss_scale + self.pred_weight = nn.Parameter(scale * torch.randn(num_centers, num_channels)) + self.loss_scale = copy.deepcopy(loss_scale) self.batch_dim = batch_dim self.name = None # will be set from training code @@ -637,7 +641,7 @@ def __init__(self, def forward(self, x: Tensor) -> Tensor: return PredictFunction.apply(x, self.pred_weight, self.proj_weight, - self.loss_scale, self.batch_dim, + float(self.loss_scale), self.batch_dim, self.name) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 322c53b048..604c529126 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -521,6 +521,8 @@ def __init__( self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) for _ in range(2) ] + self.scale_limiter = ScaleLimiter(max_scale=2.0) + self.norm = ExpNorm(embed_dim) self.predict_loss = PredictLoss(embed_dim, 256, loss_scale=0.01, batch_dim=1) @@ -575,6 +577,8 @@ def forward( src = self.bypass(src_orig, src) + src = self.scale_limiter(src) + src = self.norm(src) src = self.predict_loss(src) From ff175ad8f585f1785e8a636c6c41a393e3e272d7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Apr 2025 10:38:59 +0800 Subject: [PATCH 0309/1191] revert EdenBounce to Eden2, make the lr-batches of 17500 the default, add adjustment for duration_ratio; a fix to a debug message. --- egs/librispeech/ASR/zipformer/optim.py | 89 ---------------------- egs/librispeech/ASR/zipformer/train.py | 17 ++++- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 3 files changed, 15 insertions(+), 93 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 058c69c187..9186632839 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1049,67 +1049,6 @@ def get_lr(self): -class EdenBounce(LRScheduler): - """ - Somewhat the Eden scheduler, but simpler than Eden because it does not use the notion of epoch, - only batches. - - The "bounce" refers to a cosine-schedule-like feature, which we implement as a kind of - extension of the notion of the "warmup period". This means that if we are within - "bounce_period" batches of one of the "bounce times", we slow down the learning rate - by multiplying in a sawtooth "bounce factor" which is 1.0 if we are far from any - bounce period and goes down to a value of "bounce_bottom", e.g. 0.25, if we are at exactly - at the "bounce tim". - - The basic formula (before bounce-factor) is: - lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - bounce_radius: the distance from the center to the edge of each sawtooth - cutout. - bounce_period: the distance from one sawtooth to the next. - bounce_bottom: a factor betwedn 0 and 1, which defines the bottom of each - sawtooth in the bounce-factor. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - bounce_radius: Union[int, float] = 500.0, - bounce_period: float = 4000.0, - bounce_bottom: float = 0.333, - verbose: bool = False, - ): - super().__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.bounce_radius = bounce_radius - self.bounce_period = bounce_period - self.bounce_bottom = bounce_bottom - self.verbose = verbose - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.5 - - closest_bounce_center = self.bounce_period * int(0.5 +self.batch / self.bounce_period) - bounce_distance = abs(self.batch - closest_bounce_center) - - bounce_factor = ( - 1.0 if bounce_distance > self.bounce_radius - else self.bounce_bottom - + (1.0 - self.bounce_bottom) * (bounce_distance / self.bounce_radius) - ) - - return [x * factor * bounce_factor for x in self.base_lrs] - - @@ -1137,33 +1076,6 @@ def _test_eden(): logging.info(f"last lr = {scheduler.get_last_lr()}") logging.info(f"state dict = {scheduler.state_dict()}") -def _test_eden_bounce(): - m = torch.nn.Linear(100, 100) - optim = TransformedAdam(m.parameters(), lr=0.01) - - scheduler = EdenBounce(optim, lr_batches=10000, - verbose=True, - bounce_period=100, - bounce_radius=10) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - # This is included mostly as a baseline for TransformedAdam. class Eve(Optimizer): @@ -1436,7 +1348,6 @@ def _test_transform_params(): else: hidden_dim = 200 - _test_eden_bounce() _test_transform_params() _test_scaled_adam(hidden_dim) _test_eden() diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f3f0c5b089..8dbe22ab49 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel -from optim import EdenBounce, TransformedAdam +from optim import Eden2, TransformedAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -116,6 +116,17 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: ) +def get_adjusted_lr_batches(params: AttributeDict) -> float: + # returns an adjusted form of the "lr_batches" parameter used to set the learning + # rate in the Eden2 scheduler. If we have larger batch-sizes and/or world size, + # we want to decrease the learning rate faster because the grads will be less + # noisy, so we want a smallar lr_batches. + duration_ratio = (params.max_duration * params.world_size) / params.ref_duration + lr_batches = params.lr_batches * (duration_ratio ** -0.5) + logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") + return lr_batches + + def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): # get underlying nn.Module @@ -407,7 +418,7 @@ def get_parser(): parser.add_argument( "--lr-batches", type=float, - default=10000, + default=17500, help="""Number of steps that affects how rapidly the learning rate decreases. We suggest not to change this.""", ) @@ -1373,7 +1384,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = EdenBounce(optimizer, params.lr_batches) + scheduler = Eden2(optimizer, get_adjusted_lr_batches(params.lr_batches)) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 696e795ba8..7fbd6e3e48 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1349,7 +1349,7 @@ def forward( if torch.jit.is_scripting() or torch.jit.is_tracing(): pass - elif random.random() < 0.001 and not self.training: + elif random.random() < 0.001: self._print_attn_entropy(attn_weights) attn_weights = nn.functional.dropout( From 14a10e1d7817591d468424f3e50fcd60222d948e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Apr 2025 10:52:13 +0800 Subject: [PATCH 0310/1191] Bug fix; more comment. --- egs/librispeech/ASR/zipformer/train.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 8dbe22ab49..3839449c8c 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -118,9 +118,15 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: def get_adjusted_lr_batches(params: AttributeDict) -> float: # returns an adjusted form of the "lr_batches" parameter used to set the learning - # rate in the Eden2 scheduler. If we have larger batch-sizes and/or world size, - # we want to decrease the learning rate faster because the grads will be less - # noisy, so we want a smallar lr_batches. + # rate in the Eden2 scheduler. + # We want the final LR to be based on the geometric mean of "how much data we + # have seen" and "how many batches we have seen". + # an easier way to look at it is this: the formula for learning rate depends + # on (cur_batch / lr_batches). if we write this as: + # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches + # then the numerator is a geometric mean of "how many batches we have seen" + # and "how much data we have seen". We can achieve this by setting + # lr_batches = params.lr_batches * (duration_ratio ** -0.5). duration_ratio = (params.max_duration * params.world_size) / params.ref_duration lr_batches = params.lr_batches * (duration_ratio ** -0.5) logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") @@ -1384,7 +1390,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Eden2(optimizer, get_adjusted_lr_batches(params.lr_batches)) + scheduler = Eden2(optimizer, get_adjusted_lr_batches(params)) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 87b464b9896e8e7f5ee3182b1be99c6687edc3b7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Apr 2025 12:44:28 +0800 Subject: [PATCH 0311/1191] Revert rand_floor from .0 to .25 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 9249bd174b..a2fc49e65b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -464,7 +464,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.0, + rand_floor: FloatLike = 0.25, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From 0edc4335be886352273e4f8cadc5310567ed9ebd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Apr 2025 12:51:38 +0800 Subject: [PATCH 0312/1191] Rename bypass module to Residual module; add Residual module to Zipformer2Encoder and remove the bypass-norm. --- egs/librispeech/ASR/zipformer/zipformer.py | 49 +++++++++++----------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7fbd6e3e48..95c7929554 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -494,8 +494,8 @@ def __init__( self.name = None # will be set from training loop self.randomize_scale = copy.deepcopy(randomize_scale) - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule( + # self.bypass implements layer skipping as well as learnable scale on a residual term; see its default values. + self.residual = ResidualModule( embed_dim, ) @@ -572,7 +572,7 @@ def forward( src = src + self.feed_forward3(src) - src = self.bypass(src_orig, src) + src = self.residual(src_orig, src) src = self.scale_limiter(src) @@ -662,8 +662,6 @@ def streaming_forward( src = src + self.feed_forward2(src) - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) self_attn, cached_val2 = self.self_attn2.streaming_forward( src, @@ -684,7 +682,7 @@ def streaming_forward( src = self.norm(src) - src = self.bypass(src_orig, src) + src = self.residual(src_orig, src) src = self.norm(src) @@ -735,10 +733,11 @@ def __init__( ) self.num_layers = num_layers - bypass_dim = dim - encoder_layer.embed_dim - assert bypass_dim >= 0 - if bypass_dim > 0: - self.norm_bypass = ExpNorm(bypass_dim) + + self.residual = ResidualModule(encoder_layer.embed_dim) + + #bypass_dim = dim - encoder_layer.embed_dim + self.copy_bypass = Identity() self.whiten = Whiten( num_groups=1, @@ -778,6 +777,7 @@ def forward( src, bypass = src[..., :layer_dim], src[..., layer_dim:] + src_orig = src for i, mod in enumerate(self.layers): src = mod( src, @@ -789,12 +789,11 @@ def forward( # randomize_factor can be viewed as a simple version of an # importance-sampling factor. + src = self.residual(src_orig, src) src = self.whiten(src) if num_channels > layer_dim: - # we pass the bypass through the norm layer mainly to prevent the model from having an incentive - # to pass the more informative feature dimensions through the bypass. - bypass = self.norm_bypass(bypass) + bypass = self.copy_bypass(bypass) src = torch.cat((src, bypass), dim=-1) return src @@ -872,9 +871,9 @@ def streaming_forward( return src, new_states -class BypassModule(nn.Module): +class ResidualModule(nn.Module): """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + An nn.Module that implements a learnable residual scale, and also randomized per-sequence layer-skipping. The bypass is limited during early stages of training to be close to "straight-through", i.e. to not do the bypass operation much initially, in order to force all the modules to learn something. @@ -889,22 +888,22 @@ def __init__( scale_max: FloatLike = 1.0, ): super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.direct_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) self.skip_rate = copy.deepcopy(skip_rate) self.straight_through_rate = copy.deepcopy(straight_through_rate) self.scale_min = copy.deepcopy(scale_min) self.scale_max = copy.deepcopy(scale_max) - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 corresponds to bypassing - # this module. + def _get_direct_scale(self, batch_size: int): + # returns scale of shape (num_channels,), + # or (batch_size, num_channels,). This is the + # scale on the non-residual term, and 1-direct_scale is the + # scale on the residual. if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.bypass_scale + return self.direct_scale else: ans = limit_param_value( - self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + self.direct_scale, min=float(self.scale_min), max=float(self.scale_max) ) skip_rate = float(self.skip_rate) if skip_rate != 0.0: @@ -927,8 +926,8 @@ def forward(self, src_orig: Tensor, src: Tensor): Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) Returns: something with the same shape as src and src_orig """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale + direct_scale = self._get_direct_scale(src.shape[1]) + return src_orig + (src - src_orig) * direct_scale From 59bd88064934784fe4046b9b2cbb1592605952c8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Apr 2025 21:44:26 +0800 Subject: [PATCH 0313/1191] Make predict loss be returned explicitly --- egs/librispeech/ASR/zipformer/decode.py | 2 +- egs/librispeech/ASR/zipformer/model.py | 8 +-- egs/librispeech/ASR/zipformer/scaling.py | 65 +++++----------------- egs/librispeech/ASR/zipformer/train.py | 12 +++- egs/librispeech/ASR/zipformer/zipformer.py | 19 ++++--- 5 files changed, 43 insertions(+), 63 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 2200b2a672..504d1d94d2 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -452,7 +452,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + encoder_out, encoder_out_lens, _predict_loss = model.forward_encoder(feature, feature_lens) hyps = [] diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 991bf78dff..a8fdbd5e4d 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -155,12 +155,12 @@ def forward_encoder( src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens + return encoder_out, encoder_out_lens, predict_loss def forward_ctc( self, @@ -428,7 +428,7 @@ def forward( y = k2.ragged.cat([y, y], axis=0) # Compute encoder outputs - encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -493,7 +493,7 @@ def forward( if use_cr_ctc: reconstruction_loss = reconstruction_loss * 0.5 - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss def forward_reconstruction_loss(self, diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c9931939fd..f94772a477 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -565,56 +565,24 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -class PredictFunction(torch.autograd.Function): - # cross-prediction thing to be used in conjunction with CR-CTC or anything else - # where the batch is repeated twice but with different spec-augment. - # assume channel dim is dim -1, batch_dim is specified by user. - @staticmethod - @custom_fwd - def forward(ctx, x, pred_weight, proj_weight, loss_scale, batch_dim, name): - ctx.save_for_backward(x, pred_weight, proj_weight) - ctx.loss_scale = loss_scale # loss_scale is relative to existing grad. - ctx.name = name - ctx.batch_dim = batch_dim - return x - - @staticmethod - @custom_bwd - def backward(ctx, x_grad): - x, pred_weight, proj_weight = ctx.saved_tensors +def predict_loss(x: Tensor, pred_weight: Tensor, proj_weight: Tensor, batch_dim: int, name: str) -> Tensor: + batch_size = x.shape[batch_dim] + assert batch_size % 2 == 0, "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." - batch_size = x.shape[ctx.batch_dim] - assert batch_size % 2 == 0, "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." + # is_top1 true if this is the highest dot-product with x_proj. + x_proj = torch.matmul(x, proj_weight.t()) + top1_indexes = torch.max(x_proj, dim=-1, keepdim=True)[1] + top1_indexes = torch.roll(top1_indexes, batch_size // 2, batch_dim) - # is_top1 true if this is the highest dot-product with x_proj. - x_proj = torch.matmul(x, proj_weight.t()) - top1_indexes = torch.max(x_proj, dim=-1, keepdim=True)[1] - top1_indexes = torch.roll(top1_indexes, batch_size // 2, ctx.batch_dim) - - # take loss_scale from other copy of this utterance/item. - with torch.cuda.amp.autocast(enabled=False): - loss_scale = ctx.loss_scale * x_grad.to(torch.float).norm(dim=-1, keepdim=True) * x.to(torch.float).norm(dim=-1, keepdim=True) - loss_scale = torch.roll(loss_scale, batch_size // 2, ctx.batch_dim) - loss_scale = loss_scale.clamp(max=0.1) # TEMP - - x = x.detach() - pred_weight = pred_weight.detach() - x.requires_grad = True - pred_weight.requires_grad = True - - with torch.enable_grad(): - x_pred = torch.matmul(x, pred_weight.t()) - logprobs = x_pred.log_softmax(dim=-1) - loss = -torch.gather(logprobs, dim=-1, index=top1_indexes) - loss_scaled = loss * loss_scale - loss_scaled.backward(gradient=torch.ones_like(loss_scaled)) - if random.random() < 0.002: - logging.info(f"name={ctx.name}, mean loss before scale = {loss.mean()}") + x_pred = torch.matmul(x, pred_weight.t()) + logprobs = x_pred.log_softmax(dim=-1) + loss = -torch.gather(logprobs, dim=-1, index=top1_indexes) + if random.random() < 0.002: + logging.info(f"name={name}, mean loss before scale = {loss.mean()}") - extra_grad = torch.nan_to_num(x.grad, nan=0.0).clamp(min=-0.1, max=0.1) - return x_grad + extra_grad, pred_weight.grad, None, None, None, None + return loss class PredictLoss(nn.Module): """ @@ -625,7 +593,6 @@ class PredictLoss(nn.Module): def __init__(self, num_channels: int, num_centers: int, - loss_scale: FloatLike = ScheduledFloat((0.0, 1.0e-06), (1000.0, 1.0e-05), (2000.0, 0.01)), batch_dim: int = 0): super().__init__() scale = num_channels ** -0.5 @@ -633,16 +600,14 @@ def __init__(self, scale * torch.randn(num_centers, num_channels), persistent=True) self.pred_weight = nn.Parameter(scale * torch.randn(num_centers, num_channels)) - self.loss_scale = copy.deepcopy(loss_scale) self.batch_dim = batch_dim self.name = None # will be set from training code def forward(self, x: Tensor) -> Tensor: - return PredictFunction.apply(x, self.pred_weight, self.proj_weight, - float(self.loss_scale), self.batch_dim, - self.name) + return predict_loss(x, self.pred_weight, self.proj_weight, + self.batch_dim, self.name) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f3f0c5b089..24b5c61fd3 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -481,6 +481,13 @@ def get_parser(): help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", ) + parser.add_argument( + "--predict-loss-scale", + type=float, + default=0.01, + help="Prediction of random k-means after widest zipformer layer" + ) + parser.add_argument( "--time-mask-ratio", type=float, @@ -950,7 +957,7 @@ def compute_loss( supervision_segments = None with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -992,6 +999,8 @@ def compute_loss( loss += reconstruction_loss_scale * reconstruction_loss + loss += params.predict_loss_scale * predict_loss + if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -1011,6 +1020,7 @@ def compute_loss( info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.use_cr_ctc: info["cr_loss"] = cr_loss.detach().cpu().item() + info["predict_loss"] = predict_loss.detach().cpu().item() info["recon_loss"] = reconstruction_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 604c529126..4f17e33843 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -195,6 +195,10 @@ def set_downsample_factor(cur_downsample, ds): encoder.encoder_index = i # <-- will be used in streaming_forward encoders.append(encoder) + if downsampling_factor[i] == max(downsampling_factor): + self.predict_loss = PredictLoss(cur_downsample*input_dim, 256, batch_dim=1) + + cur_downsample = set_downsample_factor(cur_downsample, output_downsampling_factor) self.encoders = nn.ModuleList(encoders) @@ -265,6 +269,8 @@ def truncate(x, downsampling_factor): max_len = (orig_seq_len + downsampling_factor - 1) // downsampling_factor return x[:max_len] if x.shape[0] > max_len else x + max_ds = max(self.downsampling_factor) + for module in self.encoders: if isinstance(module, Zipformer2Encoder): i = module.encoder_index # was set in this class's __init__ function. @@ -283,6 +289,9 @@ def truncate(x, downsampling_factor): else attn_mask[::ds, ::ds] ), ) + if ds == max_ds: + predict_loss = self.predict_loss(x) + else: x = module(x) @@ -296,7 +305,7 @@ def truncate(x, downsampling_factor): warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - return x, lengths + return x, lengths, predict_loss def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -525,8 +534,6 @@ def __init__( self.norm = ExpNorm(embed_dim) - self.predict_loss = PredictLoss(embed_dim, 256, loss_scale=0.01, batch_dim=1) - def forward( self, @@ -581,8 +588,6 @@ def forward( src = self.norm(src) - src = self.predict_loss(src) - return src def streaming_forward( @@ -2011,11 +2016,11 @@ def _test_zipformer_main(causal: bool = False): batch_size = 6 # make it even, as PredictLoss requires even batch size. seq_len = 21 # Just make sure the forward pass runs. - f = c( + f, lengths, predict_loss = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) - f[0].sum().backward() + f.sum().backward() c.eval() f = c( torch.randn(seq_len, batch_size, input_dim), From f858bb8e0a5cfde5637795ea12e6d0baf8862c3a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Apr 2025 21:51:28 +0800 Subject: [PATCH 0314/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index f94772a477..85fc14a512 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -582,7 +582,7 @@ def predict_loss(x: Tensor, pred_weight: Tensor, proj_weight: Tensor, batch_dim: if random.random() < 0.002: logging.info(f"name={name}, mean loss before scale = {loss.mean()}") - return loss + return loss.sum() # we reduce with sum in what we return. class PredictLoss(nn.Module): """ From 954473e198a8af99dc8fd6a393d4e938d4d6d2d6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Apr 2025 23:36:05 +0800 Subject: [PATCH 0315/1191] Rework PredictModule with mean normalization, different way of getting codes. --- egs/librispeech/ASR/zipformer/scaling.py | 35 +++++++++++++--------- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 85fc14a512..c256e0cacd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -565,19 +565,25 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, pred_weight: Tensor, proj_weight: Tensor, batch_dim: int, name: str) -> Tensor: +def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, batch_dim: int, name: str) -> Tensor: batch_size = x.shape[batch_dim] assert batch_size % 2 == 0, "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." - # is_top1 true if this is the highest dot-product with x_proj. - x_proj = torch.matmul(x, proj_weight.t()) - top1_indexes = torch.max(x_proj, dim=-1, keepdim=True)[1] - top1_indexes = torch.roll(top1_indexes, batch_size // 2, batch_dim) + with torch.no_grad(): + x_proj = torch.matmul(x, proj_weight.t()) + # subtract mean. + x_proj = x_proj - x_proj.mean(dim=tuple(range(0, x.ndim - 1))) + codes = (x_proj > 0).to(torch.int64) # codes: (..., 8), all between 0 and 1 + codes = codes * (2 ** torch.arange(8, device=x.device)) # multiply codes by (1, 2, 4, 8, ..) + indexes = codes.sum(dim=-1, keepdim=True) + + + indexes = torch.roll(indexes, batch_size // 2, batch_dim) - x_pred = torch.matmul(x, pred_weight.t()) + x_pred = predictor(x) logprobs = x_pred.log_softmax(dim=-1) - loss = -torch.gather(logprobs, dim=-1, index=top1_indexes) + loss = -torch.gather(logprobs, dim=-1, index=indexes) if random.random() < 0.002: logging.info(f"name={name}, mean loss before scale = {loss.mean()}") @@ -586,27 +592,28 @@ def predict_loss(x: Tensor, pred_weight: Tensor, proj_weight: Tensor, batch_dim: class PredictLoss(nn.Module): """ - Adds an auxiliary loss based on predicting the top-1 of inner-products with - a random set of vectors, of the "other copy" of each utterance (it assumes - you are doing something like CR-CTC so + Adds an auxiliary loss based on predicting the top-1 of 256 randomized codebook + entries. """ def __init__(self, num_channels: int, - num_centers: int, batch_dim: int = 0): super().__init__() scale = num_channels ** -0.5 self.register_buffer('proj_weight', - scale * torch.randn(num_centers, num_channels), + scale * torch.randn(8, num_channels), persistent=True) - self.pred_weight = nn.Parameter(scale * torch.randn(num_centers, num_channels)) + num_hidden = max(1024, num_channels) + self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), + nn.ReLU(), + nn.Linear(num_hidden, 256)) self.batch_dim = batch_dim self.name = None # will be set from training code def forward(self, x: Tensor) -> Tensor: - return predict_loss(x, self.pred_weight, self.proj_weight, + return predict_loss(x, self.predictor, self.proj_weight, self.batch_dim, self.name) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4f17e33843..d28aaf014f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -196,7 +196,7 @@ def set_downsample_factor(cur_downsample, ds): encoders.append(encoder) if downsampling_factor[i] == max(downsampling_factor): - self.predict_loss = PredictLoss(cur_downsample*input_dim, 256, batch_dim=1) + self.predict_loss = PredictLoss(cur_downsample*input_dim, batch_dim=1) cur_downsample = set_downsample_factor(cur_downsample, output_downsampling_factor) From e56a3b8e79de8622a923e452c9cd9b7d7c4d2a32 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Apr 2025 16:37:17 +0800 Subject: [PATCH 0316/1191] Fix predict_loss to work in test mode. --- egs/librispeech/ASR/zipformer/scaling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c256e0cacd..c9169f3b69 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -567,7 +567,9 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, batch_dim: int, name: str) -> Tensor: batch_size = x.shape[batch_dim] - assert batch_size % 2 == 0, "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." + if batch_size % 2 != 0: + assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." + return torch.tensor(0.0, device=x.device) with torch.no_grad(): x_proj = torch.matmul(x, proj_weight.t()) From c37c44b373d426cf2b364d796d9b864f84027ce9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Apr 2025 18:40:05 +0800 Subject: [PATCH 0317/1191] Do predict_loss at end of each Zipformer2Encoder. --- egs/librispeech/ASR/zipformer/zipformer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d28aaf014f..b9cd644e32 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -195,9 +195,6 @@ def set_downsample_factor(cur_downsample, ds): encoder.encoder_index = i # <-- will be used in streaming_forward encoders.append(encoder) - if downsampling_factor[i] == max(downsampling_factor): - self.predict_loss = PredictLoss(cur_downsample*input_dim, batch_dim=1) - cur_downsample = set_downsample_factor(cur_downsample, output_downsampling_factor) @@ -269,14 +266,15 @@ def truncate(x, downsampling_factor): max_len = (orig_seq_len + downsampling_factor - 1) // downsampling_factor return x[:max_len] if x.shape[0] > max_len else x - max_ds = max(self.downsampling_factor) + + predict_loss = 0.0 for module in self.encoders: if isinstance(module, Zipformer2Encoder): i = module.encoder_index # was set in this class's __init__ function. ds = self.downsampling_factor[i] x = truncate(x, ds) - x = module( + x, this_pred_loss = module( x, chunk_size=chunk_size, src_key_padding_mask=( @@ -289,8 +287,7 @@ def truncate(x, downsampling_factor): else attn_mask[::ds, ::ds] ), ) - if ds == max_ds: - predict_loss = self.predict_loss(x) + predict_loss += this_pred_loss * ds else: x = module(x) @@ -305,7 +302,7 @@ def truncate(x, downsampling_factor): warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - return x, lengths, predict_loss + return x, lengths, predict_loss / len(self.downsampling_factor) def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -759,6 +756,8 @@ def __init__( grad_scale=0.025, ) + self.predict_loss = PredictLoss(dim, batch_dim=1) + def forward( self, @@ -809,7 +808,7 @@ def forward( bypass = self.norm_bypass(bypass) src = torch.cat((src, bypass), dim=-1) - return src + return src, self.predict_loss(src) def streaming_forward( self, From ad0fd41ca94e8058c05c91a5823e1dab5ec949e5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Apr 2025 18:41:24 +0800 Subject: [PATCH 0318/1191] Bug fix, divide by output_downsampling_factor and remove old code to truncate channels --- egs/librispeech/ASR/zipformer/zipformer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b9cd644e32..39e60a8379 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -287,13 +287,11 @@ def truncate(x, downsampling_factor): else attn_mask[::ds, ::ds] ), ) - predict_loss += this_pred_loss * ds + predict_loss += this_pred_loss * (ds / self.output_downsampling_factor) else: x = module(x) - x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 From 4b361133e67dc948643090b184d91087e53a2ff3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Apr 2025 20:30:36 +0800 Subject: [PATCH 0319/1191] Use the src_key_padding_mask --- egs/librispeech/ASR/zipformer/scaling.py | 22 ++++++++++++++-------- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c9169f3b69..64f8285191 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -565,7 +565,9 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, batch_dim: int, name: str) -> Tensor: +def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, + batch_dim: int, name: str, + mask: Optional[Tensor]) -> Tensor: batch_size = x.shape[batch_dim] if batch_size % 2 != 0: assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." @@ -573,22 +575,26 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, batch_dim with torch.no_grad(): x_proj = torch.matmul(x, proj_weight.t()) + if mask is not None: + x_proj = x_proj - (x_proj * mask).sum(dim=tuple(range(0, x.ndim - 1))) / mask.sum(dim=tuple(range(0, x.ndim - 1))) + else: + x_proj = x_proj - x_proj.mean(dim=tuple(range(0, x.ndim - 1))) + # subtract mean. - x_proj = x_proj - x_proj.mean(dim=tuple(range(0, x.ndim - 1))) codes = (x_proj > 0).to(torch.int64) # codes: (..., 8), all between 0 and 1 codes = codes * (2 ** torch.arange(8, device=x.device)) # multiply codes by (1, 2, 4, 8, ..) indexes = codes.sum(dim=-1, keepdim=True) - indexes = torch.roll(indexes, batch_size // 2, batch_dim) - - x_pred = predictor(x) logprobs = x_pred.log_softmax(dim=-1) loss = -torch.gather(logprobs, dim=-1, index=indexes) if random.random() < 0.002: - logging.info(f"name={name}, mean loss before scale = {loss.mean()}") + logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") + + if mask is not None: + loss = loss * mask.to(loss.dtype) return loss.sum() # we reduce with sum in what we return. @@ -614,9 +620,9 @@ def __init__(self, def forward(self, - x: Tensor) -> Tensor: + x: Tensor, mask: Optional[Tensor] = None) -> Tensor: return predict_loss(x, self.predictor, self.proj_weight, - self.batch_dim, self.name) + self.batch_dim, self.name, mask) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 39e60a8379..d551073998 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -806,7 +806,8 @@ def forward( bypass = self.norm_bypass(bypass) src = torch.cat((src, bypass), dim=-1) - return src, self.predict_loss(src) + return src, self.predict_loss(src, (src_key_padding_mask.t().unsqueeze(-1).logical_not() + if src_key_padding_mask is not None else None)) def streaming_forward( self, From dafb421cb3d146c788349d73fe8ed859790ac46d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Apr 2025 20:31:11 +0800 Subject: [PATCH 0320/1191] Bug fix RE mask --- egs/librispeech/ASR/zipformer/scaling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 64f8285191..a0777d7ffd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -573,6 +573,8 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." return torch.tensor(0.0, device=x.device) + mask = mask.to(x.dtype) + with torch.no_grad(): x_proj = torch.matmul(x, proj_weight.t()) if mask is not None: @@ -594,7 +596,7 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") if mask is not None: - loss = loss * mask.to(loss.dtype) + loss = loss * mask return loss.sum() # we reduce with sum in what we return. From b2bfd196d9cfcdf6c764f01cc386baef2226b365 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Apr 2025 20:06:46 +0800 Subject: [PATCH 0321/1191] Extend Residual so scales can add to more than one. --- egs/librispeech/ASR/zipformer/scaling.py | 3 +- egs/librispeech/ASR/zipformer/zipformer.py | 55 ++++++++-------------- 2 files changed, 21 insertions(+), 37 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8995b534b6..8224e837b8 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -573,7 +573,8 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." return torch.tensor(0.0, device=x.device) - mask = mask.to(x.dtype) + if mask is not None: + mask = mask.to(x.dtype) with torch.no_grad(): x_proj = torch.matmul(x, proj_weight.t()) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ab92e7763d..a0e89f9e2b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -890,54 +890,37 @@ class ResidualModule(nn.Module): """ def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0, + self, + embed_dim: int, + function_scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), ): super().__init__() - self.direct_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_direct_scale(self, batch_size: int): - # returns scale of shape (num_channels,), - # or (batch_size, num_channels,). This is the - # scale on the non-residual term, and 1-direct_scale is the - # scale on the residual. + self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.subtract_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.function_scale_min = copy.deepcopy(function_scale_min) + + + def _get_scales(self, batch_size: int): if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.direct_scale + return 1.0 - (self.subtract_scale * self.function_scale), self.function_scale else: - ans = limit_param_value( - self.direct_scale, min=float(self.scale_min), max=float(self.scale_max) + function_scale = limit_param_value( + self.function_scale, min=float(self.function_scale_min), max=1.0, ) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = ( - torch.rand((batch_size, 1), device=ans.device) - < straight_through_rate - ) - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans + subtract_scale = limit_param_value( + self.subtract_scale, min=0.0, max=1.0, + ) + residual_scale = 1.0 - (subtract_scale * function_scale) + return residual_scale, function_scale def forward(self, src_orig: Tensor, src: Tensor): """ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) Returns: something with the same shape as src and src_orig """ - direct_scale = self._get_direct_scale(src.shape[1]) - return src_orig + (src - src_orig) * direct_scale + residual_scale, function_scale = self._get_scales(src.shape[1]) + return residual_scale * src_orig + function_scale * src From ced4a8141dad4dc1e83c810125ac1556b31d1275 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Apr 2025 12:41:43 +0800 Subject: [PATCH 0322/1191] Have 256 independent codebooks, mean and variance norm. --- egs/librispeech/ASR/zipformer/scaling.py | 26 +++++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8224e837b8..ccba8e307f 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -577,21 +577,27 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, mask = mask.to(x.dtype) with torch.no_grad(): + # get the indexes. project, then mean-and-variance-norm, then + # take mx. x_proj = torch.matmul(x, proj_weight.t()) - if mask is not None: - x_proj = x_proj - (x_proj * mask).sum(dim=tuple(range(0, x.ndim - 1))) / mask.sum(dim=tuple(range(0, x.ndim - 1))) - else: - x_proj = x_proj - x_proj.mean(dim=tuple(range(0, x.ndim - 1))) + with torch.cuda.amp.autocast(enabled=False): + x_proj = x_proj.to(torch.float) + # Mean subtraction and variance normalization. + dims = tuple(range(0, x.ndim - 1)) + if mask is not None: + x_masked = x_proj * mask + x_proj = x_proj - x_masked.sum(dim=dims) / mask.sum(dim=dims) + x_proj = x_proj * (mask.sum(dim=dims) / ((x_masked ** 2).sum(dim=dims) + 1.0e-10)).sqrt() + else: + x_proj = x_proj - x_proj.mean(dim=dims) + x_proj = x_proj / (x_proj ** 2).mean(dim=dims).sqrt() - # subtract mean. - codes = (x_proj > 0).to(torch.int64) # codes: (..., 8), all between 0 and 1 - codes = codes * (2 ** torch.arange(8, device=x.device)) # multiply codes by (1, 2, 4, 8, ..) - indexes = codes.sum(dim=-1, keepdim=True) + indexes = torch.max(x_proj, dim=-1)[1] indexes = torch.roll(indexes, batch_size // 2, batch_dim) x_pred = predictor(x) logprobs = x_pred.log_softmax(dim=-1) - loss = -torch.gather(logprobs, dim=-1, index=indexes) + loss = -torch.gather(logprobs, dim=-1, index=indexes.unsqueeze(-1)) if random.random() < 0.002: logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") @@ -612,7 +618,7 @@ def __init__(self, super().__init__() scale = num_channels ** -0.5 self.register_buffer('proj_weight', - scale * torch.randn(8, num_channels), + scale * torch.randn(256, num_channels), persistent=True) num_hidden = max(1024, num_channels) self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), From 027f5cefa4e9c4e61b231f22afc5c3f7263ce7f5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Apr 2025 12:42:41 +0800 Subject: [PATCH 0323/1191] Use LeakyReLU in PredictLoss --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ccba8e307f..26c91704b9 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -622,7 +622,7 @@ def __init__(self, persistent=True) num_hidden = max(1024, num_channels) self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), - nn.ReLU(), + nn.LeakyReLU(), nn.Linear(num_hidden, 256)) self.batch_dim = batch_dim self.name = None # will be set from training code From f44fa9a89cec06ab53da7d82815d54dcd6c80917 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Apr 2025 11:03:43 +0800 Subject: [PATCH 0324/1191] Simplify Residual (no subtract_scale); refactor OrthogonalLinear so we can in future make penalty_scale a schedule. --- egs/librispeech/ASR/zipformer/scaling.py | 21 +++++++++++++-------- egs/librispeech/ASR/zipformer/zipformer.py | 20 +++++++------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 26c91704b9..8744e12d9c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -638,12 +638,14 @@ def forward(self, class OrthogonalLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, name, in_groups, out_groups, group_size): + def forward(ctx, x: Tensor, weight: Tensor, name: str, in_groups: int, + out_groups: int, group_size: int, penalty_scale: float): ctx.save_for_backward(x, weight) ctx.name = name ctx.out_groups = out_groups ctx.in_groups = in_groups ctx.group_size = group_size + ctx.penalty_scale = penalty_scale assert not (in_groups > 0 and out_groups > 0) return torch.matmul(x, weight.t()) @@ -657,14 +659,13 @@ def backward(ctx, y_grad): else: x_grad = None - out_groups, in_groups, group_size = ctx.out_groups, ctx.in_groups, ctx.group_size if weight.requires_grad: weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), x.reshape(-1, x.shape[-1])) - penalty_scale = 20.0 * weight_grad.abs().mean() + penalty_scale = ctx.penalty_scale * weight_grad.abs().mean() with torch.enable_grad(): weight = weight.detach() @@ -718,7 +719,7 @@ def diag_inplace(z): prod.backward(gradient=prod * penalty_scale) - do_print = random.random() < 0.005 + do_print = random.random() < 0.002 if do_print: # we print a normalized version of the loss, by dividing by the # number of rows. @@ -730,7 +731,7 @@ def diag_inplace(z): weight_grad += weight.grad else: weight_grad = None - return x_grad, weight_grad, None, None, None, None + return x_grad, weight_grad, None, None, None, None, None @@ -760,7 +761,9 @@ class OrthogonalLinear(nn.Linear): bias: if True, include a bias term. initial_scale: a factor that allows you to increase or decrease the initial scale of the weight (and bias, if present) - + penalty_scale: a scale on the penalty on non-orthogonality (this will + be multiplied by the average-absolute-value of the + backpropagated gradient). """ # if in_groups or out_groups are set to >1, the orthogonal constraint # will be set per group. both of them cannot be >1. @@ -772,6 +775,7 @@ def __init__(self, group_size: int = -1, bias: bool = True, initial_scale: float = 1.0, + penalty_scale: FloatLike = 20.0, ): super().__init__(in_channels, out_channels, bias=bias) self.name = None @@ -782,6 +786,7 @@ def __init__(self, elif out_groups > 0 and group_size == -1: group_size = out_channels // out_groups self.group_size = group_size + self.penalty_scale = copy.deepcopy(penalty_scale) # the same scaling as for ScaledLinear. with torch.no_grad(): @@ -796,7 +801,7 @@ def forward(self, x: Tensor): ans = OrthogonalLinearFunction.apply(x, self.weight, self.name, self.in_groups, self.out_groups, - self.group_size) + self.group_size, float(self.penalty_scale)) if self.bias is not None: ans = ans + self.bias return ans @@ -1197,7 +1202,7 @@ def backward(ctx, y_grad: Tensor): # (x**2).mean() > 1.0, but it starts of small if we are close to 1.0 # so we don't suddenly add large gradients that could be destabilizing. eps = 0.01 - loss_scale = eps * ((x ** 2).mean() - ctx.max_scale).relu() + loss_scale = eps * ((x ** 2).mean() - ctx.max_scale).relu() # caution: this is a bug, there is no sqrt(). y_grad_rms = (y_grad ** 2).mean().sqrt() # y_grad_rms is a scaling factor for the gradient contribution, since we # don't know at this point the total scale of the main loss. diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a0e89f9e2b..34f244ff49 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -896,30 +896,24 @@ def __init__( ): super().__init__() self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.subtract_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) self.function_scale_min = copy.deepcopy(function_scale_min) - def _get_scales(self, batch_size: int): - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return 1.0 - (self.subtract_scale * self.function_scale), self.function_scale - else: + def _get_scales(self): + function_scale = self.function_scale + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: function_scale = limit_param_value( - self.function_scale, min=float(self.function_scale_min), max=1.0, - ) - subtract_scale = limit_param_value( - self.subtract_scale, min=0.0, max=1.0, + function_scale, min=float(self.function_scale_min), max=1.0, ) - residual_scale = 1.0 - (subtract_scale * function_scale) - - return residual_scale, function_scale + residual_scale = 1.0 - function_scale + return residual_scale, function_scale def forward(self, src_orig: Tensor, src: Tensor): """ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) Returns: something with the same shape as src and src_orig """ - residual_scale, function_scale = self._get_scales(src.shape[1]) + residual_scale, function_scale = self._get_scales() return residual_scale * src_orig + function_scale * src From 2b711fef94858baf46a1cf33886e18ed569bf79a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Apr 2025 11:09:56 +0800 Subject: [PATCH 0325/1191] Make OrthogonalUpsample orthogonality penalty disappear. --- egs/librispeech/ASR/zipformer/scaling.py | 5 +++-- egs/librispeech/ASR/zipformer/zipformer.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8744e12d9c..7aeffd4ca3 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -664,7 +664,10 @@ def backward(ctx, y_grad): if weight.requires_grad: weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), x.reshape(-1, x.shape[-1])) + else: + weight_grad = None + if weight.requires_grad and ctx.penalty_scale != 0.0: penalty_scale = ctx.penalty_scale * weight_grad.abs().mean() with torch.enable_grad(): @@ -729,8 +732,6 @@ def diag_inplace(z): # add the extra gradient term from the orthogonality loss. weight_grad += weight.grad - else: - weight_grad = None return x_grad, weight_grad, None, None, None, None, None diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 34f244ff49..25c9c7d7b8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -981,7 +981,9 @@ class OrthogonalUpsample(torch.nn.Module): def __init__(self, channels: int, proj_dim: int): super().__init__() assert proj_dim <= channels - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + # gradually make smaller and then turn off the non-orthognality penalty. + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, + penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. # it will be interpreted by get_parameter_groups_with_lrs() self.proj.lr_scale = 0.75 From 651b0ebc180a1d89d791d30cde79d85a9fe395fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Apr 2025 11:33:37 +0800 Subject: [PATCH 0326/1191] Fixes to ScaleLimiter to avoid infinities in backprop and for cosmetics; reduce self_attn in_proj initialization. --- egs/librispeech/ASR/zipformer/scaling.py | 23 +++++++++++----------- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8744e12d9c..e62e975a57 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1190,9 +1190,9 @@ def _approx_inverse_erf(x): class ScaleLimiterFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, max_scale: float): + def forward(ctx, x: Tensor, max_var: float): ctx.save_for_backward(x) - ctx.max_scale = max_scale + ctx.max_var = max_var return x @staticmethod @@ -1202,26 +1202,27 @@ def backward(ctx, y_grad: Tensor): # (x**2).mean() > 1.0, but it starts of small if we are close to 1.0 # so we don't suddenly add large gradients that could be destabilizing. eps = 0.01 - loss_scale = eps * ((x ** 2).mean() - ctx.max_scale).relu() # caution: this is a bug, there is no sqrt(). - y_grad_rms = (y_grad ** 2).mean().sqrt() - # y_grad_rms is a scaling factor for the gradient contribution, since we + loss_scale = eps * ((x.to(torch.float) ** 2).mean() - ctx.max_var).relu() + y_grad_abs_mean = y_grad.abs().mean() + # y_grad_abs_mean is a scaling factor for the gradient contribution, since we # don't know at this point the total scale of the main loss. # the grad of (x ** 2).mean() would be 2 * x. we absorb the factor of 2 # into eps, which is just an arbitrary smallish value. - return y_grad + (loss_scale * y_grad_rms) * x, None + return y_grad + (loss_scale * y_grad_abs_mean) * x, None class ScaleLimiter(torch.nn.Module): """ - Tries to make the rms value of the features no greater than self.max_scale, by + Tries to make the average square value of the features no greater than self.max_var, by adding a penalty. This is not per dimension, but globally. Assumes channel dim is -1 and the input shape has >1 dimension. + Caution: max_var is actually a maximum variance. """ - def __init__(self, max_scale: FloatLike = 1.0, prob: FloatLike = 1.0): + def __init__(self, max_var: FloatLike = 1.0, prob: FloatLike = 1.0): super().__init__() self.name = None - self.max_scale = max_scale + self.max_var = max_var self.prob = prob def forward(self, x: Tensor) -> Tensor: @@ -1232,10 +1233,10 @@ def forward(self, x: Tensor) -> Tensor: # (x ** 2).mean() > 1.0, the penalty will tend to reduce the value # of (x ** 2). if random.random() < 0.001: - logging.info(f"name={self.name}, max_scale={float(self.max_scale)}, prob={float(self.prob)}, x_rms={(x**2).mean().sqrt().item()}") + logging.info(f"name={self.name}, max_var={float(self.max_var)}, prob={float(self.prob)}, x_rms={(x**2).mean().sqrt().item()}") prob = float(self.prob) if prob > 0 and random.random() < prob: - return ScaleLimiterFunction.apply(x, float(self.max_scale)) + return ScaleLimiterFunction.apply(x, float(self.max_var)) else: return x diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 34f244ff49..6a37530ecb 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -525,7 +525,7 @@ def __init__( self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) for _ in range(2) ] - self.scale_limiter = ScaleLimiter(max_scale=2.0) + self.scale_limiter = ScaleLimiter(max_var=2.0) self.norm = ExpNorm(embed_dim) @@ -1176,7 +1176,7 @@ def __init__( # it would be necessary to apply the scaling factor in the forward function. self.in_proj = ScaledLinear( embed_dim, in_proj_dim, - bias=True, initial_scale=query_head_dim**-0.25 + bias=True, initial_scale=0.5 * query_head_dim**-0.25 ) self.whiten_keys = Whiten( From 841a1a39acee954a9de1163a8745144b7cc74cbc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Apr 2025 17:28:12 +0800 Subject: [PATCH 0327/1191] Reduce end-batch of function_scale_min schedule from 20k to 4k. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 5e2551d4e8..2398333552 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -892,7 +892,7 @@ class ResidualModule(nn.Module): def __init__( self, embed_dim: int, - function_scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + function_scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (4000.0, 0.2), default=0), ): super().__init__() self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) From 42f508ec3de165f56d635e282087a9f36c673106 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 19 Apr 2025 13:55:35 +0800 Subject: [PATCH 0328/1191] Reduce function_scale_min in ResidualModule to a constant of 0.1 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2398333552..8b1a38c8e5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -892,7 +892,7 @@ class ResidualModule(nn.Module): def __init__( self, embed_dim: int, - function_scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (4000.0, 0.2), default=0), + function_scale_min: FloatLike = 0.1, ): super().__init__() self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) From 02c55154d8945c815ba50669bbffbb373e0a93e4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Apr 2025 20:33:01 +0800 Subject: [PATCH 0329/1191] Reduce frontend dropout from .1 (final) to .0 --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 47a28b098f..7dec291f5f 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -677,7 +677,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, out_channels=params.embed_dim, - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + dropout=0.0, ) return encoder_embed From 8cf2a2ce25dcfeb176a6e46dde866ce069a4db45 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Apr 2025 12:28:26 +0800 Subject: [PATCH 0330/1191] Decrease codebook_size from 256 to 64. --- egs/librispeech/ASR/zipformer/scaling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index dbf35250f6..03ecd156f8 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -614,16 +614,17 @@ class PredictLoss(nn.Module): """ def __init__(self, num_channels: int, - batch_dim: int = 0): + batch_dim: int = 0, + codebook_size: int = 64): super().__init__() scale = num_channels ** -0.5 self.register_buffer('proj_weight', - scale * torch.randn(256, num_channels), + scale * torch.randn(codebook_size, num_channels), persistent=True) num_hidden = max(1024, num_channels) self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), nn.LeakyReLU(), - nn.Linear(num_hidden, 256)) + nn.Linear(num_hidden, codebook_size)) self.batch_dim = batch_dim self.name = None # will be set from training code From 332ae2af5abdc3be6ea881c6cb22835cd524c9d7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Apr 2025 22:34:22 +0800 Subject: [PATCH 0331/1191] Change codebook-size from 64 to 63 to avoid a bug in torch. https://github.com/pytorch/pytorch/issues/152017 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 03ecd156f8..c0b1946b4e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -615,7 +615,7 @@ class PredictLoss(nn.Module): def __init__(self, num_channels: int, batch_dim: int = 0, - codebook_size: int = 64): + codebook_size: int = 63): super().__init__() scale = num_channels ** -0.5 self.register_buffer('proj_weight', From d59e2c64643810d63f2598552cd9588563426afa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Apr 2025 16:56:32 +0800 Subject: [PATCH 0332/1191] Introduce base-dim to zipformer --- egs/librispeech/ASR/zipformer/model.py | 4 +- egs/librispeech/ASR/zipformer/train.py | 151 +++++++++++++-------- egs/librispeech/ASR/zipformer/zipformer.py | 17 +-- 3 files changed, 106 insertions(+), 66 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index a8fdbd5e4d..34f9146810 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -509,10 +509,10 @@ def forward_reconstruction_loss(self, log_mels: log-mel features of shape (batch_size, T, num_mels) encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) """ + batch_size = log_mels.shape[0] + num_mels = log_mels.shape[2] if use_cr_ctc: - batch_size = log_mels.shape[0] log_mels = torch.roll(log_mels, batch_size // 2, dims=0) - num_mels = log_mels.shape[2] pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) T_embed = pred_mels.shape[1] diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 7dec291f5f..86e37388cd 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -82,7 +82,7 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 +from zipformer2 import Zipformer2 from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -144,47 +144,91 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: module.name = name +def lookup(params: AttributeDict, name: str): + """ + Interprets numerical arguments in `params` by taking into account base-dim; + also parses comma-separated lists of integers, turning them into tuples. + If a particular attribute ending in "dim" is not present we look up + the same name but ending in "factor", and multiply the elements by base_dim. + """ + try: + attr = getattr(params, name) + try: + attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints + if len(attr) == 1: + attr = attr[0] + except: + pass # leave attr as it is, e.g. a string. + return attr + except AttributeError as e: + if name[-3:] != "dim": + raise e + try: + attr = getattr(params, name[:-3] + "multiple") + if isinstance(attr, str): + attr = tuple(map(int, attr.split(","))) # tuple of ints + base_dim = params.base_dim + attr = tuple([i * base_dim for i in attr]) + if len(attr) == 1: + attr = attr[0] + else: # assume int. + assert isinstance(attr, (int, float)), (name, attr) + attr = attr * params.base_dim + return attr + except AttributeError as e: + raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") + + + + def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="2,2,3,4,3,2", + default="3,4,4,4,4,4,4,4,4", help="Number of zipformer encoder layers per stack, comma separated.", ) parser.add_argument( "--downsampling-factor", type=str, - default="1,2,4,8,4,2", + default="1,2,4,4,8,8,4,4,2", help="Downsampling factor for each stack of encoder layers.", ) parser.add_argument( - "--embed-dim", + "--base-dim", type=int, - default=192, - help="Output dimension of frontend, also determines bypass dimensions in zipformer layers.", + default=64, + help="Dimension that, via multiples, defines the dimensions of the model." ) parser.add_argument( - "--feedforward-dim", + "--embed-multiple", + type=int, + default=4, + help="Output dimension of frontend, as multiple of base-dim; bypass dimensions in zipformer layers.", + ) + + parser.add_argument( + "--feedforward-multiple", type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + default="4,4,4,4,4,4,4,4,4", + help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", ) parser.add_argument( "--num-heads", type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + default="4,4,4,4,8,8,4,4,4", + help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) parser.add_argument( - "--encoder-dim", + "--encoder-multiple", type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + default="3,4,6,6,8,8,6,6,4", + help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) parser.add_argument( @@ -218,22 +262,22 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="31,31,15,15,15,31", + default="31,31,15,15,15,15,15,15,31", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", ) parser.add_argument( - "--decoder-dim", + "--decoder-multiple", type=int, - default=512, - help="Embedding dimension in the decoder model.", + default=8, + help="Factor by which embedding dimension in the decoder model is larger than base-dim.", ) parser.add_argument( - "--joiner-dim", + "--joiner-multiple", type=int, - default=512, + default=4, help="""Dimension used in the joiner model. Outputs from the encoder and decoder model are projected to this dimension before adding. @@ -241,10 +285,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( - "--attention-decoder-dim", + "--attention-decoder-multiple", type=int, - default=512, - help="""Dimension used in the attention decoder""", + default=8, + help="""Factor by which attention decoder dim is larger than base-dim""", ) parser.add_argument( @@ -255,10 +299,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( - "--attention-decoder-attention-dim", + "--attention-decoder-attention-multiple", type=int, - default=512, - help="""Attention dimension used in attention decoder""", + default=8, + help="""Determines attention dimension used in attention decoder""", ) parser.add_argument( @@ -269,10 +313,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( - "--attention-decoder-feedforward-dim", + "--attention-decoder-feedforward-multiple", type=int, - default=2048, - help="""Feedforward dimension used in attention decoder""", + default=4, + help="""Factor by which feedforward hidden dim in attention decoder is larger than attention-decoder-dim""" ) parser.add_argument( @@ -676,7 +720,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: # sampling rate. encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, - out_channels=params.embed_dim, + out_channels=lookup(params, "embed_dim"), dropout=0.0, ) return encoder_embed @@ -684,23 +728,22 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Zipformer2( - input_dim=params.embed_dim, + input_dim=lookup(params, "embed_dim"), output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), + downsampling_factor=lookup(params, "downsampling_factor"), + num_encoder_layers=lookup(params, "num_encoder_layers"), + encoder_dim=lookup(params, "encoder_dim"), + query_head_dim=lookup(params, "query_head_dim"), + pos_head_dim=lookup(params, "pos_head_dim"), + value_head_dim=lookup(params, "value_head_dim"), pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), - warmup_batches=4000.0, + num_heads=lookup(params, "num_heads"), + feedforward_multiple=lookup(params, "feedforward_multiple"), + cnn_module_kernel=lookup(params, "cnn_module_kernel"), + dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), + chunk_size=lookup(params, "chunk_size"), + left_context_frames=lookup(params, "left_context_frames"), ) return encoder @@ -708,7 +751,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, + decoder_dim=lookup(params, "decoder_dim"), blank_id=params.blank_id, context_size=params.context_size, ) @@ -718,9 +761,9 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: output_downsampling_factor = 2 joiner = Joiner( - encoder_dim=params.embed_dim * output_downsampling_factor, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, + encoder_dim=lookup(params, "embed_dim") * output_downsampling_factor, + decoder_dim=lookup(params, "decoder_dim"), + joiner_dim=lookup(params, "joiner_dim"), vocab_size=params.vocab_size, ) return joiner @@ -729,12 +772,12 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_attention_decoder_model(params: AttributeDict) -> nn.Module: decoder = AttentionDecoderModel( vocab_size=params.vocab_size, - decoder_dim=params.attention_decoder_dim, + decoder_dim=lookup(params, "attention_decoder_dim"), num_decoder_layers=params.attention_decoder_num_layers, - attention_dim=params.attention_decoder_attention_dim, + attention_dim=lookup(params, "attention_decoder_attention_dim"), num_heads=params.attention_decoder_num_heads, - feedforward_dim=params.attention_decoder_feedforward_dim, - memory_dim=params.embed_dim * output_downsampling_factor, + feedforward_dim=lookup(params, "attention_decoder_feedforward_dim"), + memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, sos_id=params.sos_id, eos_id=params.eos_id, ignore_id=params.ignore_id, @@ -772,8 +815,8 @@ def get_model(params: AttributeDict) -> nn.Module: decoder=decoder, joiner=joiner, attention_decoder=attention_decoder, - encoder_dim=output_downsampling_factor * params.embed_dim, - decoder_dim=params.decoder_dim, + encoder_dim=output_downsampling_factor * lookup(params, "embed_dim"), + decoder_dim=lookup(params, "decoder_dim"), vocab_size=params.vocab_size, use_transducer=params.use_transducer, use_ctc=params.use_ctc, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8b1a38c8e5..d0e580d363 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -70,15 +70,13 @@ class Zipformer2(EncoderInterface): value_head_dim (int or Tuple[int]): dimension of value in each attention head num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module pos_dim (int): the dimension of each positional-encoding vector prior to projection, e.g. 128. dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. causal (bool): if True, support chunkwise causal convolution. This should not hurt WER as no modeling power is lost, but the convolution modules will be slightly slower and use more memory. Enables use of the chunk_size and @@ -102,11 +100,10 @@ def __init__( pos_head_dim: Union[int, Tuple[int]] = 4, value_head_dim: Union[int, Tuple[int]] = 12, num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, + feedforward_multiple: Union[int, Tuple[int]] = 4, cnn_module_kernel: Union[int, Tuple[int]] = 31, pos_dim: int = 192, dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, causal: bool = False, chunk_size: Tuple[int] = [-1], left_context_frames: Tuple[int] = [-1], @@ -137,7 +134,7 @@ def _to_tuple(x): self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) pos_head_dim = _to_tuple(pos_head_dim) self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_dim = _to_tuple(feedforward_dim) + feedforward_multiple = _to_tuple(feedforward_multiple) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) self.causal = causal @@ -178,7 +175,7 @@ def set_downsample_factor(cur_downsample, ds): query_head_dim=query_head_dim[i], pos_head_dim=pos_head_dim[i], value_head_dim=value_head_dim[i], - feedforward_dim=feedforward_dim[i], + feedforward_multiple=feedforward_multiple[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], causal=causal, @@ -470,7 +467,7 @@ class Zipformer2EncoderLayer(nn.Module): Args: embed_dim: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (required). + feedforward_multiple: determines the hidden dimension of the feedforward module dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module (default=31). @@ -488,7 +485,7 @@ def __init__( query_head_dim: int, pos_head_dim: int, value_head_dim: int, - feedforward_dim: int, + feedforward_multiple: int, dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, causal: bool = False, @@ -515,9 +512,9 @@ def __init__( self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] + feedforward_dim = embed_dim * feedforward_multiple self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) From c3f8d673d6d144eb347983daaf52286aa65dba9a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Apr 2025 17:45:14 +0800 Subject: [PATCH 0333/1191] Fix wrong import. --- egs/librispeech/ASR/zipformer/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 86e37388cd..8d0105e0fa 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -82,7 +82,7 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer2 +from zipformer import Zipformer2 from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -277,7 +277,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--joiner-multiple", type=int, - default=4, + default=8, help="""Dimension used in the joiner model. Outputs from the encoder and decoder model are projected to this dimension before adding. @@ -776,7 +776,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module: num_decoder_layers=params.attention_decoder_num_layers, attention_dim=lookup(params, "attention_decoder_attention_dim"), num_heads=params.attention_decoder_num_heads, - feedforward_dim=lookup(params, "attention_decoder_feedforward_dim"), + feedforward_dim=params.attention_decoder_feedforward_multiple * lookup(params, "attention_decoder_attention_dim"), memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, sos_id=params.sos_id, eos_id=params.eos_id, From 968bc0cf46099d4880d6200c248d2a9fc487b957 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Apr 2025 18:50:59 +0800 Subject: [PATCH 0334/1191] Fix printing AttributeDict for dtype --- icefall/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/utils.py b/icefall/utils.py index 83e8106322..e69ab8cd05 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -186,7 +186,7 @@ def __str__(self, indent: int = 2): tmp = {} for k, v in self.items(): # PosixPath is ont JSON serializable - if isinstance(v, pathlib.Path) or isinstance(v, torch.device): + if isinstance(v, (pathlib.Path, torch.device, torch.dtype)): v = str(v) tmp[k] = v return json.dumps(tmp, indent=indent, sort_keys=True) From 2731e1ab50d1de9c2544825badb95991dd521e67 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Apr 2025 11:10:18 +0800 Subject: [PATCH 0335/1191] Bug-fix for crash in ctc_loss --- egs/librispeech/ASR/zipformer/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 34f9146810..0e4f1f32c8 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -184,9 +184,9 @@ def forward_ctc( ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) return ctc_loss @@ -212,9 +212,9 @@ def forward_cr_ctc( ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) From e1f6beb84c76e19423e09e6e0750daa1c6ee7877 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Apr 2025 11:42:53 +0800 Subject: [PATCH 0336/1191] Decrease initialization of self_attn in_proj --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d0e580d363..231bed93aa 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1175,7 +1175,7 @@ def __init__( # it would be necessary to apply the scaling factor in the forward function. self.in_proj = ScaledLinear( embed_dim, in_proj_dim, - bias=True, initial_scale=0.5 * query_head_dim**-0.25 + bias=True, initial_scale=0.25 * query_head_dim**-0.25 ) self.whiten_keys = Whiten( From 3cb2b09241ed6c46a1583cac4794879d90f79138 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Apr 2025 12:13:45 +0800 Subject: [PATCH 0337/1191] Fix comments --- egs/librispeech/ASR/zipformer/model.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 0e4f1f32c8..36799e70d3 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -182,9 +182,20 @@ def forward_ctc( # Compute CTC log-prob ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). This activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. + targets=targets.long(), input_lengths=encoder_out_lens.long(), target_lengths=target_lengths.long(), reduction="sum", From f17b4ac9880f563864235f8f6fb7bc598841c4b7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Apr 2025 14:29:01 +0800 Subject: [PATCH 0338/1191] Increase encoder-multiple from 4 to 6. --- egs/librispeech/ASR/zipformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 8d0105e0fa..e3cbe72bba 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -206,8 +206,8 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--embed-multiple", type=int, - default=4, - help="Output dimension of frontend, as multiple of base-dim; bypass dimensions in zipformer layers.", + default=6, + help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", ) parser.add_argument( From 1e65fee090762baba925c5f75ff371dca9d38ecc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Apr 2025 18:51:05 +0800 Subject: [PATCH 0339/1191] Use warmup schedule for prediction loss; increase initial value of warmup schedule of reconstruction loss to 4.0; make the warmup schedules geometric; get rid of simple_loss_scale warmup. --- egs/librispeech/ASR/zipformer/train.py | 35 +++++++++++++------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index e3cbe72bba..d821c2a031 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -679,7 +679,9 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. + scale on pruned loss (for transducer) and the reconstruction and prediction + losses. Expressed in terms of the "adjusted batch count", i.e. the + normalized batch count after adjusting for changes in batch size. """ params = AttributeDict( { @@ -697,7 +699,7 @@ def get_params() -> AttributeDict: # parameters for attention-decoder "ignore_id": -1, "label_smoothing": 0.1, - "warm_step": 2000, + "warm_step": 4000, "env_info": get_env_info(), } ) @@ -995,7 +997,6 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) batch_idx_train = params.batch_idx_train - warm_step = params.warm_step texts = batch["supervisions"]["text"] y = sp.encode(texts, out_type=int) @@ -1033,20 +1034,17 @@ def compute_loss( loss = 0.0 + adjusted_batch_count = params.batch_idx_train + warm_step = params.warm_step + def warmup_schedule(scale, initial_factor): + # geometric warmup schedules. + warmup_factor = (1. if adjusted_batch_count >= warm_step else + initial_factor ** (1. - (adjusted_batch_count / warm_step))) + return scale * warmup_factor + if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + simple_loss_scale = params.simple_loss_scale + pruned_loss_scale = warmup_schedule(1.0, 0.05) loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: @@ -1054,11 +1052,12 @@ def compute_loss( if use_cr_ctc: loss += params.cr_loss_scale * cr_loss - reconstruction_loss_scale = (params.reconstruction_loss_scale * - max(1.0, 2.0 - 1.0 * (batch_idx_train / warm_step))) + reconstruction_loss_scale = warmup_schedule(params.reconstruction_loss_scale, 4.0) loss += reconstruction_loss_scale * reconstruction_loss + predict_loss_scale = warmup_schedule(params.predict_loss_scale, 4.0) + loss += params.predict_loss_scale * predict_loss if params.use_attention_decoder: From e3c4936c2cc904bab2c6ced49d8afe50a847dc02 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Apr 2025 12:31:18 +0800 Subject: [PATCH 0340/1191] Make PredictLoss predictor and predicted different, let encoder output predict whole output. --- egs/librispeech/ASR/zipformer/scaling.py | 39 +++++++++++++--------- egs/librispeech/ASR/zipformer/zipformer.py | 10 ++++-- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c0b1946b4e..f14c7b4a42 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -565,7 +565,8 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, +def predict_loss(x: Tensor, y: Tensor, + predictor: nn.Module, proj_weight: Tensor, batch_dim: int, name: str, mask: Optional[Tensor]) -> Tensor: batch_size = x.shape[batch_dim] @@ -579,20 +580,20 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, with torch.no_grad(): # get the indexes. project, then mean-and-variance-norm, then # take mx. - x_proj = torch.matmul(x, proj_weight.t()) + y_proj = torch.matmul(y, proj_weight.t()) with torch.cuda.amp.autocast(enabled=False): - x_proj = x_proj.to(torch.float) + y_proj = y_proj.to(torch.float) # Mean subtraction and variance normalization. dims = tuple(range(0, x.ndim - 1)) if mask is not None: - x_masked = x_proj * mask - x_proj = x_proj - x_masked.sum(dim=dims) / mask.sum(dim=dims) - x_proj = x_proj * (mask.sum(dim=dims) / ((x_masked ** 2).sum(dim=dims) + 1.0e-10)).sqrt() + y_masked = y_proj * mask + y_proj = y_proj - y_masked.sum(dim=dims) / mask.sum(dim=dims) + y_proj = y_proj * (mask.sum(dim=dims) / ((y_masked ** 2).sum(dim=dims) + 1.0e-10)).sqrt() else: - x_proj = x_proj - x_proj.mean(dim=dims) - x_proj = x_proj / (x_proj ** 2).mean(dim=dims).sqrt() + y_proj = y_proj - y_proj.mean(dim=dims) + y_proj = y_proj / (y_proj ** 2).mean(dim=dims).sqrt() - indexes = torch.max(x_proj, dim=-1)[1] + indexes = torch.max(y_proj, dim=-1)[1] indexes = torch.roll(indexes, batch_size // 2, batch_dim) x_pred = predictor(x) @@ -611,18 +612,22 @@ class PredictLoss(nn.Module): """ Adds an auxiliary loss based on predicting the top-1 of 256 randomized codebook entries. + x_channels: the number of channels of the thing we want to learn + y_channels: the number of channels of the thing that generates the codebook + indexes to learn. No grad will be backpropagated to this. """ def __init__(self, - num_channels: int, + x_channels: int, + y_channels: int, batch_dim: int = 0, codebook_size: int = 63): super().__init__() - scale = num_channels ** -0.5 + scale = y_channels ** -0.5 self.register_buffer('proj_weight', - scale * torch.randn(codebook_size, num_channels), + scale * torch.randn(codebook_size, y_channels), persistent=True) - num_hidden = max(1024, num_channels) - self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), + num_hidden = max(1024, x_channels) + self.predictor = nn.Sequential(nn.Linear(x_channels, num_hidden), nn.LeakyReLU(), nn.Linear(num_hidden, codebook_size)) self.batch_dim = batch_dim @@ -630,8 +635,10 @@ def __init__(self, def forward(self, - x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - return predict_loss(x, self.predictor, self.proj_weight, + x: Tensor, + y: Tensor, + mask: Optional[Tensor] = None) -> Tensor: + return predict_loss(x, y, self.predictor, self.proj_weight, self.batch_dim, self.name, mask) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 231bed93aa..fc053956d8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -750,7 +750,7 @@ def __init__( grad_scale=0.025, ) - self.predict_loss = PredictLoss(dim, batch_dim=1) + self.predict_loss = PredictLoss(encoder_layer.embed_dim, dim, batch_dim=1) def forward( @@ -795,6 +795,9 @@ def forward( # randomize_factor can be viewed as a simple version of an # importance-sampling factor. + + encoder_output = src + src = self.residual(src_orig, src) src = self.whiten(src) @@ -802,8 +805,9 @@ def forward( bypass = self.copy_bypass(bypass) src = torch.cat((src, bypass), dim=-1) - return src, self.predict_loss(src, (src_key_padding_mask.t().unsqueeze(-1).logical_not() - if src_key_padding_mask is not None else None)) + return src, self.predict_loss(encoder_output, src, + (src_key_padding_mask.t().unsqueeze(-1).logical_not() + if src_key_padding_mask is not None else None)) def streaming_forward( self, From e0a93273cc06a3d70df5a85f4fde878ff547c746 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Apr 2025 13:08:49 +0800 Subject: [PATCH 0341/1191] Halve self_attn_weight initial scale. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fc053956d8..4131d11e74 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1179,7 +1179,7 @@ def __init__( # it would be necessary to apply the scaling factor in the forward function. self.in_proj = ScaledLinear( embed_dim, in_proj_dim, - bias=True, initial_scale=0.25 * query_head_dim**-0.25 + bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) self.whiten_keys = Whiten( From ad8973fa6ee442749774ecff3432b17bdea3a2a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Apr 2025 18:58:47 +0800 Subject: [PATCH 0342/1191] Increase beta1 in optim.py to 0.99. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9186632839..5cc9902dcc 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -385,7 +385,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.98, + beta1=0.99, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scalar_lr_scale=0.1, From 18d02418792d24b386a24a8da6d85df738befa41 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Apr 2025 21:18:08 +0800 Subject: [PATCH 0343/1191] Increase beta1 further from .99 to .995 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 5cc9902dcc..380e790d3c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -385,7 +385,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.99, + beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scalar_lr_scale=0.1, From 7f66e6db9c81d412fd58ef63ae66bf1878f8d3b8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 28 Apr 2025 13:26:18 +0800 Subject: [PATCH 0344/1191] Introduce the Sched3 optimizer. --- egs/librispeech/ASR/zipformer/optim.py | 91 +++++++++++++++++++++++++- egs/librispeech/ASR/zipformer/train.py | 6 +- 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 380e790d3c..158130002f 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1009,7 +1009,6 @@ class Eden2(LRScheduler): where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches and then stays constant at 1. - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam Args: @@ -1050,6 +1049,66 @@ def get_lr(self): +class Sched3(LRScheduler): + """ + Sched3 scheduler. + + The basic formula is as follows. p is a supplied power, e.g. 1.0, but could + also be, say, 0.8. lr_batches is a number of batches that defines when we start + decreasing significantly. "batch" is the current batch count. + + lr = warmup * min((lr_batches / batch)^p, exp(-batch / (e * lr_batches))) + + where e is the mathematical constant e. This expression is equivalent to: + min_q [ (q * lr_batches) / batch)^q ] where the minimum is taken over + the continuous range 0 <= q <= p. The left hand side of the min in the formula + for lr corresponds to q == p, i.e. we hit the rhs of the allowed range. + + `warmup` increases linearly from warmup_start to 1 over `warmup_batches` batches + and then stays constant at 1. + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + p: float = 1.0, + verbose: bool = False, + ): + super().__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.warmup_batches = warmup_batches + self.p = p + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + lr_batches = self.lr_batches + batch = max(self.batch, 0.1) # avoid division by zero + factor = min((lr_batches / batch) ** self.p, + 2.71828 ** (-batch / (2.71828 * lr_batches))) + + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + + + def _test_eden(): @@ -1077,6 +1136,30 @@ def _test_eden(): logging.info(f"state dict = {scheduler.state_dict()}") +def _test_sched3(): + m = torch.nn.Linear(100, 100) + optim = TransformedAdam(m.parameters(), lr=0.03) + + scheduler = Sched3(optim, lr_batches=100, verbose=True, warmup_batches=20) + + + for step in range(200): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + if step % 10 == 0: + logging.info(f"test_sched3: step={step}, last lr = {scheduler.get_last_lr()}") + + logging.info(f"state dict = {scheduler.state_dict()}") + + # This is included mostly as a baseline for TransformedAdam. class Eve(Optimizer): """ @@ -1321,6 +1404,7 @@ def _test_scaled_adam(hidden_dim: int): logging.info(f"output_magnitudes = {output_magnitudes}") def _test_transform_params(): + # caution: this has occasional errors. group = { "bias_min_rms": 0.001, "weight_min_rms": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5, "weight_max_rms": 20.0, "bias_max_rms": 20.0 } for scale in [ 0.0, 1.0e-05, 0.001, 0.01, 1.0, 10.0 ]: @@ -1348,6 +1432,7 @@ def _test_transform_params(): else: hidden_dim = 200 - _test_transform_params() - _test_scaled_adam(hidden_dim) + #_test_transform_params() + #_test_scaled_adam(hidden_dim) _test_eden() + _test_sched3() diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index d821c2a031..4c1cf38737 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel -from optim import Eden2, TransformedAdam +from optim import Sched3, TransformedAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -118,7 +118,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: def get_adjusted_lr_batches(params: AttributeDict) -> float: # returns an adjusted form of the "lr_batches" parameter used to set the learning - # rate in the Eden2 scheduler. + # rate in the Sched3 scheduler. # We want the final LR to be based on the geometric mean of "how much data we # have seen" and "how many batches we have seen". # an easier way to look at it is this: the formula for learning rate depends @@ -1442,7 +1442,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Eden2(optimizer, get_adjusted_lr_batches(params)) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 0a20303bacb0c7b030fad13677026425bf8956aa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Apr 2025 10:29:33 +0800 Subject: [PATCH 0345/1191] Fix formula of sched3 scheduler. --- egs/librispeech/ASR/zipformer/optim.py | 34 ++++++++++++++++++++------ 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 158130002f..21c9f9787a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1057,13 +1057,31 @@ class Sched3(LRScheduler): also be, say, 0.8. lr_batches is a number of batches that defines when we start decreasing significantly. "batch" is the current batch count. - lr = warmup * min((lr_batches / batch)^p, exp(-batch / (e * lr_batches))) + lr = warmup * [ (p * lr_batches / batch)^p if batch > p*e*lr_batches, else + exp(-batch / (e * lr_batches))) where e is the mathematical constant e. This expression is equivalent to: - min_q [ (q * lr_batches) / batch)^q ] where the minimum is taken over + factor = min_q [ (q * lr_batches) / batch)^q ] where the minimum is taken over the continuous range 0 <= q <= p. The left hand side of the min in the formula for lr corresponds to q == p, i.e. we hit the rhs of the allowed range. + * notes for derivation: define x == lr_batches/batch, and let factor=min_q [(q*x) +. In wolframalpha.com, note that: + d/dp (q * x)^q has a root at (q = 1/(ex)). If 1/ex > p, then q is fixed to the limit, + q==p, so factor == (p * x)^p. Else, i.e. when 1/ex <= p, + when p > 1 / ex, factor == (q * x)^1 = (1/(ex)*x)^(1/ex) = (1/e)^(1/ex = e^{-1/ex}. + + So the rule is: + if batch/(e*lr_batches) > p, i.e. if batch > p*e*lr_batches, + factor = (p * lr_batches/batch)^p. + else, factor = exp(-batch/(lr_batches*e)) + Plot[ If [ x > 0.8 * Exp[1] * 10, 0.8*10/x, Exp[-x/(10*Exp[1])] ], {x, 0, 50}] + + + + + + `warmup` increases linearly from warmup_start to 1 over `warmup_batches` batches and then stays constant at 1. @@ -1093,9 +1111,11 @@ def __init__( def get_lr(self): lr_batches = self.lr_batches - batch = max(self.batch, 0.1) # avoid division by zero - factor = min((lr_batches / batch) ** self.p, - 2.71828 ** (-batch / (2.71828 * lr_batches))) + e = 2.71828 + batch = self.batch + p = self.p + factor = ((p * lr_batches / batch) ** p if batch > p * e * lr_batches else + e ** (-batch / (e * lr_batches))) warmup_factor = ( 1.0 @@ -1140,10 +1160,10 @@ def _test_sched3(): m = torch.nn.Linear(100, 100) optim = TransformedAdam(m.parameters(), lr=0.03) - scheduler = Sched3(optim, lr_batches=100, verbose=True, warmup_batches=20) + scheduler = Sched3(optim, lr_batches=100, p=0.8, verbose=True, warmup_batches=20) - for step in range(200): + for step in range(300): x = torch.randn(200, 100).detach() x.requires_grad = True y = m(x) From b5e40838db29736f4a9804c7a0abc38e338eb904 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Apr 2025 16:15:24 +0800 Subject: [PATCH 0346/1191] Set weights_only to True in torch.load --- icefall/checkpoint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index d31ce13019..69c05d7ca1 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -110,7 +110,7 @@ def load_checkpoint( TODO: document it """ logging.info(f"Loading checkpoint from {filename}") - checkpoint = torch.load(filename, map_location="cpu") + checkpoint = torch.load(filename, map_location="cpu", weights_only=True) if next(iter(checkpoint["model"])).startswith("module."): logging.info("Loading checkpoint saved by DDP") @@ -163,7 +163,7 @@ def average_checkpoints( """ n = len(filenames) - avg = torch.load(filenames[0], map_location=device)["model"] + avg = torch.load(filenames[0], map_location=device, weights_only=True)["model"] # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr @@ -178,7 +178,7 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - state_dict = torch.load(filenames[i], map_location=device)["model"] + state_dict = torch.load(filenames[i], map_location=device, weights_only=True)["model"] for k in uniqued_names: avg[k] += state_dict[k] @@ -421,8 +421,8 @@ def average_checkpoints_with_averaged_model( device: Move checkpoints to this device before averaging. """ - state_dict_start = torch.load(filename_start, map_location=device) - state_dict_end = torch.load(filename_end, map_location=device) + state_dict_start = torch.load(filename_start, map_location=device, weights_only=True) + state_dict_end = torch.load(filename_end, map_location=device, weights_only=True) average_period = state_dict_start["average_period"] From 3f726fe5c757066c0f8607352cc5fea1dda8c99a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Apr 2025 16:28:29 +0800 Subject: [PATCH 0347/1191] Set weights_only to False in torch.load --- icefall/checkpoint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 69c05d7ca1..045793a609 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -110,7 +110,7 @@ def load_checkpoint( TODO: document it """ logging.info(f"Loading checkpoint from {filename}") - checkpoint = torch.load(filename, map_location="cpu", weights_only=True) + checkpoint = torch.load(filename, map_location="cpu", weights_only=False) if next(iter(checkpoint["model"])).startswith("module."): logging.info("Loading checkpoint saved by DDP") @@ -163,7 +163,7 @@ def average_checkpoints( """ n = len(filenames) - avg = torch.load(filenames[0], map_location=device, weights_only=True)["model"] + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr @@ -178,7 +178,7 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - state_dict = torch.load(filenames[i], map_location=device, weights_only=True)["model"] + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] for k in uniqued_names: avg[k] += state_dict[k] @@ -421,8 +421,8 @@ def average_checkpoints_with_averaged_model( device: Move checkpoints to this device before averaging. """ - state_dict_start = torch.load(filename_start, map_location=device, weights_only=True) - state_dict_end = torch.load(filename_end, map_location=device, weights_only=True) + state_dict_start = torch.load(filename_start, map_location=device, weights_only=False) + state_dict_end = torch.load(filename_end, map_location=device, weights_only=False) average_period = state_dict_start["average_period"] From 0bbec759aea84b38f82929b2d3073081404d8bad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Apr 2025 15:24:27 +0800 Subject: [PATCH 0348/1191] Default config changes: num-encoder-layers 3,4,4,4,4,4,4,4,4->3,4,6,4,4,6,4, feedforward-multiple 4->3, encoder-multiple 3,4,6,6,8,8,6,6,4->3,5,8,12,12,8,5 --- egs/librispeech/ASR/zipformer/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index e3cbe72bba..8f8ed06df2 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -185,14 +185,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="3,4,4,4,4,4,4,4,4", + default="3,4,6,4,4,6,4", help="Number of zipformer encoder layers per stack, comma separated.", ) parser.add_argument( "--downsampling-factor", type=str, - default="1,2,4,4,8,8,4,4,2", + default="1,2,4,8,8,4,2", help="Downsampling factor for each stack of encoder layers.", ) @@ -213,21 +213,21 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--feedforward-multiple", type=str, - default="4,4,4,4,4,4,4,4,4", + default="3,3,3,3,3,3,3", help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", ) parser.add_argument( "--num-heads", type=str, - default="4,4,4,4,8,8,4,4,4", + default="4,4,4,8,8,4,4", help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) parser.add_argument( "--encoder-multiple", type=str, - default="3,4,6,6,8,8,6,6,4", + default="3,5,8,12,12,8,5", help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="31,31,15,15,15,15,15,15,31", + default="31,31,15,15,15,15,31", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", ) From a29d2667800c7cd5946ef10020d1083e47ea8b9c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 28 Apr 2025 17:29:45 +0800 Subject: [PATCH 0349/1191] Use beta1 = min(beta1, 1. - 1. / (10. + 0.25 * step)); was previously 0.5 * step. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9186632839..3eca1b4771 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -156,7 +156,7 @@ def momentum_step(group, p, state, grad): lr = group["lr"] step = state["step"] - beta1 = min(group["beta1"], 1. - 1. / (10. + 0.5 * step)) + beta1 = min(group["beta1"], 1. - 1. / (10. + 0.25 * step)) direct = group["direct"] try: From dd3ded4d634fe54c8c19a97b55fb4c3eed583c34 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 28 Apr 2025 11:51:16 +0800 Subject: [PATCH 0350/1191] Increase beta1 further from .995 to .9975 # Conflicts: # egs/librispeech/ASR/zipformer/optim.py --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3eca1b4771..cce08327ad 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -385,7 +385,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.98, + beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scalar_lr_scale=0.1, From 4cee8c88842d4f5d7f41e85ac7d9d97be3b39f30 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Apr 2025 17:21:26 +0800 Subject: [PATCH 0351/1191] Increase interior multiple from 12 to 14. --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 8f8ed06df2..a9351f886a 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -227,7 +227,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-multiple", type=str, - default="3,5,8,12,12,8,5", + default="3,5,8,14,14,8,5", help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) From cc2790db55258f842f7b6bed3820973435e7f65f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Apr 2025 19:59:45 +0800 Subject: [PATCH 0352/1191] Implement much-simplified version of TransformedAdam, called SimpleTransformedAdam. --- egs/librispeech/ASR/zipformer/optim.py | 134 +++++++++++++++++++++---- 1 file changed, 116 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 21c9f9787a..39be2551ac 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -365,21 +365,10 @@ class TransformedAdam(BatchedOptimizer): weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of learning the scale on the parameters. Weight tensors are defined as anything with more than one element and ndim > 1. - weight_penalty_rms: Value of root-mean-square value of weight tensor, that provides - a reference point for when we start to do adamw-style decay. bias_min_rms: Minimum root-mean-square value of bias tensors, defined as anything with more than one element and exactly one tensor dimension i.e. ndim == 1. - bias_penalty_rms: Value of root-mean-square value of bias tensor, that provides - a reference point for when we start to do adamw-style decay. - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period debug_interval: if >0, write some statistics to tensorboard every this-many steps. """ - def __init__( self, params, @@ -413,7 +402,6 @@ def __init__( bias_max_rms=bias_max_rms, bias_min_rms=bias_min_rms, weight_max_rms=weight_max_rms, - size_update_period=size_update_period, clipping_update_period=clipping_update_period, debug_interval=debug_interval, ) @@ -834,6 +822,111 @@ def _show_gradient_dominating_parameter( f" orig_rms_sq={(dominant_rms**2).item():.3e}" ) +class SimpleTransformedAdam(Optimizer): + """ + Version of TransformedAdam that doesn't do the batching or gradient clipping (may be easier to integrate + into other frameworks). + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses). + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + beta2: beta2 is the momentum constant for moving-grad-squared as in Adam. + Must satisfy 0 < beta <= beta2 < 1. + betas: a list of decay constants for momentum on the parameter-change + scales: a list of scales corresponding to each of the betas, that we multiply + each momentum-update by. Implicitly there is also a beta=0, scale=1, + i.e. a non-decayed update. + scaling_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each non-scalar parameter tensor. If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. + eps: A general-purpose epsilon to prevent division by zero + weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of + learning the scale on the parameters. Weight tensors are defined + as anything with more than one element and ndim > 1. + bias_min_rms: Minimum root-mean-square value of bias tensors, defined as anything with + more than one element and exactly one tensor dimension i.e. ndim == 1. + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + beta1=0.995, + direct=0.05, # scale on bypass of momentum (beta1) + beta2=0.98, + scalar_lr_scale=0.1, + scaling_lr_scale=0.1, + eps=1.0e-08, + weight_min_rms=0.005, + weight_max_rms=1.0, + bias_min_rms=1.0e-05, + bias_max_rms=5.0, + debug_interval=0, + ): + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + beta1=beta1, + direct=direct, + beta2=beta2, + scalar_lr_scale=scalar_lr_scale, + scaling_lr_scale=scaling_lr_scale, + eps=eps, + weight_min_rms=weight_min_rms, + bias_max_rms=bias_max_rms, + bias_min_rms=bias_min_rms, + weight_max_rms=weight_max_rms, + debug_interval=debug_interval, + ) + super().__init__(params, defaults) + + self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook) + + + + def __setstate__(self, state): + super(TransformedAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group in self.param_groups: + + for p in group['params']: + state = self.state[p] + grad = p.grad + + + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + p[:] = debug_step(group, p.detach(), state, grad) + + state["step"] = cur_step + 1 + + return loss class LRScheduler(object): @@ -1323,7 +1416,7 @@ def step(self, closure=None): return loss -def _test_scaled_adam(hidden_dim: int): +def _test_transformed_adam(hidden_dim: int): import timeit from scaling import ScaledLinear, OrthogonalLinear @@ -1331,7 +1424,7 @@ def _test_scaled_adam(hidden_dim: int): E = 100 B = 4 T = 2 - logging.info("in test_eve_cain") + logging.info("in test_transformed_adam") # device = torch.device('cuda') device = torch.device("cpu") dtype = torch.float32 @@ -1343,7 +1436,7 @@ def _test_scaled_adam(hidden_dim: int): input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - for iter in [1, 0]: + for iter in [0, 1, 2]: fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear @@ -1368,9 +1461,14 @@ def _test_scaled_adam(hidden_dim: int): ] if iter == 0: - optim = Eve(m.parameters(), lr=0.003) + optim = SimpleTransformedAdam(m.parameters(), lr=0.06, eps=1.0e-20) elif iter == 1: optim = TransformedAdam(m.named_parameters(), lr=0.06, clipping_scale=2.0, eps=1.0e-20) + elif iter == 2: + optim = Eve(m.parameters(), lr=0.003) + else: + assert "unknown iter", iter + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() @@ -1452,7 +1550,7 @@ def _test_transform_params(): else: hidden_dim = 200 - #_test_transform_params() - #_test_scaled_adam(hidden_dim) + _test_transform_params() + _test_transformed_adam(hidden_dim) _test_eden() _test_sched3() From 581f7ccbd1f8c1c24c79469dfe654411b6a3900b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Apr 2025 20:47:08 +0800 Subject: [PATCH 0353/1191] Create fake batch dim --- egs/librispeech/ASR/zipformer/optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 39be2551ac..0b8882be4b 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -922,7 +922,9 @@ def step(self, closure=None): state["step"] = 0 cur_step = 0 - p[:] = debug_step(group, p.detach(), state, grad) + def u(x): + return x.unsqueeze(0) + p[:] = debug_step(group, u(p.detach()), state, u(grad))[0] state["step"] = cur_step + 1 From 5f3e283a0ba5610b26007a09565e2bafd429f29a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Apr 2025 13:42:33 +0800 Subject: [PATCH 0354/1191] Change encoder-multiple from 3,5,8,14,14,8,5 to 3,4,8,12,12,8,4 --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index a9351f886a..d0c81bf5ef 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -227,7 +227,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-multiple", type=str, - default="3,5,8,14,14,8,5", + default="3,4,8,12,12,8,4", help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) From 4fdfa08a2a3c0d91b9047dae231f19ee8f01f586 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Apr 2025 14:09:08 +0800 Subject: [PATCH 0355/1191] Add two layers, now 3,4,7,5,4,6,4 --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index d0c81bf5ef..99c81ea2ce 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="3,4,6,4,4,6,4", + default="3,4,7,5,4,6,4", help="Number of zipformer encoder layers per stack, comma separated.", ) From e8a8216c80ab5d7d7fa6580b05fc00715fc4e13a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Apr 2025 16:14:17 +0800 Subject: [PATCH 0356/1191] Add two more layers, nearly equivalent to 524. --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 99c81ea2ce..7d4e5e4050 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="3,4,7,5,4,6,4", + default="3,5,7,5,4,7,5", help="Number of zipformer encoder layers per stack, comma separated.", ) From f01420974bbbad03a68b96f8f24e99578d8258f4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Apr 2025 16:42:04 +0800 Subject: [PATCH 0357/1191] Get rid of autocast warnings. --- egs/librispeech/ASR/zipformer/model.py | 4 ++-- egs/librispeech/ASR/zipformer/scaling.py | 10 +++++----- egs/librispeech/ASR/zipformer/train.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 36799e70d3..28f58654e4 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -305,7 +305,7 @@ def forward_transducer( # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -340,7 +340,7 @@ def forward_transducer( # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c0b1946b4e..3a9eae2760 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -307,7 +307,7 @@ def forward(ctx, x: Tensor, dim: int): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -409,7 +409,7 @@ def forward( def backward(ctx, ans_grad: Tensor) -> Tensor: x, scale = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): x, scale = x.to(torch.float32), scale.to(torch.float32) x, scale = x.detach(), scale.detach() @@ -580,7 +580,7 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, # get the indexes. project, then mean-and-variance-norm, then # take mx. x_proj = torch.matmul(x, proj_weight.t()) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): x_proj = x_proj.to(torch.float) # Mean subtraction and variance normalization. dims = tuple(range(0, x.ndim - 1)) @@ -981,7 +981,7 @@ def balancer_backward_func(x, x_grad, min_mean, max_mean, min_rms, max_rms, grad # this was taken out of the Balancer backward function. # returns modified version of x_grad. with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1336,7 +1336,7 @@ def backward(ctx, x_grad: Tensor): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index fa5ec04bd6..3863f9af92 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1196,7 +1196,7 @@ def save_bad_model(suffix: str = ""): batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast( + with torch.amp.autocast('cuda', enabled=params.use_autocast, dtype=params.dtype ): loss, loss_info = compute_loss( @@ -1644,7 +1644,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast( + with torch.amp.autocast('cuda', enabled=params.use_autocast, dtype=params.dtype ): loss, _ = compute_loss( From 556a886992df1981640406b5ad4eb92bcd5f1565 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Apr 2025 16:57:40 +0800 Subject: [PATCH 0358/1191] Some code cleanup --- egs/librispeech/ASR/zipformer/decoder.py | 1 - egs/librispeech/ASR/zipformer/scaling.py | 292 ------------------ .../ASR/zipformer/scaling_converter.py | 7 +- 3 files changed, 2 insertions(+), 298 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 357f98a807..bf49726b95 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from scaling import Balancer class Decoder(nn.Module): diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3a9eae2760..808c68f321 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -252,37 +252,6 @@ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: return torch.where(is_too_small, random_val, x).to(torch.float16) -class CutoffEstimator: - """ - Estimates cutoffs of an arbitrary numerical quantity such that a specified - proportion of items will be above the cutoff on average. - - p is the proportion of items that should be above the cutoff. - """ - - def __init__(self, p: float): - self.p = p - # total count of items - self.count = 0 - # total count of items that were above the cutoff - self.count_above = 0 - # initial cutoff value - self.cutoff = 0 - - def __call__(self, x: float) -> bool: - """ - Returns true if x is above the cutoff. - """ - ans = x > self.cutoff - self.count += 1 - if ans: - self.count_above += 1 - cur_p = self.count_above / self.count - delta_p = cur_p - self.p - if (delta_p > 0) == ans: - q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1 - q) - return ans class SoftmaxFunction(torch.autograd.Function): @@ -977,218 +946,6 @@ def streaming_forward( return x_chunk + x_causal, cache -def balancer_backward_func(x, x_grad, min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim): - # this was taken out of the Balancer backward function. - # returns modified version of x_grad. - with torch.enable_grad(): - with torch.amp.autocast('cuda', enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [i for i in range(x.ndim) if i != channel_dim] - uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = m_loss + r_loss - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = ( - (loss_grad**2) - .mean(dim=mean_dims, keepdim=True) - .sqrt() - .clamp(min=1.0e-20) - ) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) - return x_grad - - -class BalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: - (x,) = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - - try: - x_grad = balancer_backward_func(x, x_grad, min_mean, max_mean, min_rms, - max_rms, grad_scale, channel_dim) - except Exception as e: - logging.info( - f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will balance a part of it." - ) - try: - # will take a piece of x_grad in this dimension. - dim_to_split = 0 if channel_dim != 0 else 1 - size = x.shape[dim_to_split] - - x_grad_part = balancer_backward_func(x.narrow(dim_to_split, 0, size // 4), - x_grad.narrow(dim_to_split, 0, size // 4), - min_mean, max_mean, min_rms, - max_rms, grad_scale, channel_dim) - del x # save memory - x_grad = torch.cat([x_grad_part, x_grad.narrow(dim_to_split, - size // 4, - size - size // 4)], - dim_to_split) - except Exception as e: - logging.info( - f"Caught exception second time in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." - ) - return x_grad, None, None, None, None, None, None - - -class Balancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, - ): - super().__init__() - - if prob is None: - prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) - self.prob = prob - # 5% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.05) - - # actually self.num_channels is no longer needed except for an assertion. - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.min_abs = min_abs - self.max_abs = max_abs - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or not x.requires_grad - or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) - ): - return _no_op(x) - - prob = float(self.prob) - if random.random() < prob: - # The following inner-functions convert from the way we historically specified - # these limitations, as limits on the absolute value and the proportion of positive - # values, to limits on the RMS value and the (mean / stddev). - def _abs_to_rms(x): - # for normally distributed data, if the expected absolute value is x, the - # expected rms value will be sqrt(pi/2) * x. - return 1.25331413732 * x - - def _proportion_positive_to_mean(x): - def _atanh(x): - eps = 1.0e-10 - # eps is to prevent crashes if x is exactly 0 or 1. - # we'll just end up returning a fairly large value. - return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 - - def _approx_inverse_erf(x): - # 1 / (sqrt(pi) * ln(2)), - # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions - # this approximation is extremely crude and gets progressively worse for - # x very close to -1 or +1, but we mostly care about the "middle" region - # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, - # and math.erf(0.0407316414078772) = 0.045935330944660666, - # which is pretty close to 0.05. - return 0.8139535143 * _atanh(x) - - # first convert x from the range 0..1 to the range -1..1 which the error - # function returns - x = -1 + (2 * x) - return _approx_inverse_erf(x) - - min_mean = _proportion_positive_to_mean(float(self.min_positive)) - max_mean = _proportion_positive_to_mean(float(self.max_positive)) - min_rms = _abs_to_rms(float(self.min_abs)) - max_rms = _abs_to_rms(float(self.max_abs)) - grad_scale = float(self.grad_scale) - - assert x.shape[self.channel_dim] == self.num_channels - - return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim - ) - else: - return _no_op(x) - class ScaleLimiterFunction(torch.autograd.Function): @staticmethod @@ -1917,55 +1674,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) -def _test_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = Balancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - min_abs=0.0, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_sign: x = ", x) - print("_test_balancer_sign: y grad = ", y_grad) - print("_test_balancer_sign: x grad = ", x.grad) - - -def _test_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = Balancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - min_abs=0.2, - max_abs=0.7, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_magnitude: x = ", x) - print("_test_balancer_magnitude: y grad = ", y_grad) - print("_test_balancer_magnitude: x grad = ", x.grad) - - def _test_double_swish_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 1f95648a07..5a30b1e89b 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -17,9 +17,7 @@ """ This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. +Specifically, Whiten is replaced with an identity operator. """ import copy @@ -28,7 +26,6 @@ import torch import torch.nn as nn from scaling import ( - Balancer, Dropout3, ScaleGrad, SwooshL, @@ -83,7 +80,7 @@ def convert_scaled_to_non_scaled( d = {} for name, m in model.named_modules(): - if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): + if isinstance(m, (Dropout3, ScaleGrad, Whiten)): d[name] = nn.Identity() elif is_onnx and isinstance(m, SwooshR): d[name] = SwooshROnnx() From 77c4f4851104cd43d215beed2c553643eb6a65ef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Apr 2025 21:53:26 +0800 Subject: [PATCH 0359/1191] Code cleanups, use torch.compile for swoosh. --- egs/librispeech/ASR/zipformer/my_profile.py | 2 +- egs/librispeech/ASR/zipformer/scaling.py | 220 ++++++------------ .../ASR/zipformer/scaling_converter.py | 16 +- egs/librispeech/ASR/zipformer/subsampling.py | 12 +- egs/librispeech/ASR/zipformer/test_scaling.py | 10 +- egs/librispeech/ASR/zipformer/zipformer.py | 4 +- 6 files changed, 87 insertions(+), 177 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/my_profile.py b/egs/librispeech/ASR/zipformer/my_profile.py index f87613eb08..f9b26969f7 100755 --- a/egs/librispeech/ASR/zipformer/my_profile.py +++ b/egs/librispeech/ASR/zipformer/my_profile.py @@ -66,7 +66,7 @@ def _bias_norm_flops_compute(module, input, output): def _swoosh_module_flops_compute(module, input, output): - # For SwooshL and SwooshR modules + # For SwashL and SwashR modules assert len(input) == 1, len(input) # estimate as swish/silu, see icefall/profiler.py flops = input[0].numel() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 808c68f321..d5e2a6145b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1294,78 +1294,6 @@ def forward(self, x): return _no_op(x) -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16 or x.dtype == torch.bfloat16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.044 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) # Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. @@ -1419,75 +1347,57 @@ def forward(self, x: Tensor) -> Tensor: -def _swoosh_l_forward_wrapper(x): - return 0.25 * k2.swoosh_l_forward(x * 4) -def _swoosh_r_forward_wrapper(x): - return 0.25 * k2.swoosh_r_forward(x * 4) -def _swoosh_l_forward_and_deriv_wrapper(x): - y, dy_dx = k2.swoosh_l_forward_and_deriv(x * 4) - return 0.25 * y, dy_dx -def _swoosh_r_forward_and_deriv_wrapper(x): - y, dy_dx = k2.swoosh_r_forward_and_deriv(x * 4) - return 0.25 * y, dy_dx +def torch_compile(fn, *args, **kwargs): + if hasattr(torch, 'compile'): + fn = torch.compile(fn, *args, **kwargs) + return fn +def swashl(x: Tensor) -> Tensor: + zero = torch.zeros_like(x) + return 0.25 * logaddexp(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 +def swashr(x: Tensor) -> Tensor: + zero = torch.zeros_like(x) + return 0.25 * logaddexp(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 -class SwooshL(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation.""" - if torch.jit.is_scripting() or torch.jit.is_tracing(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return 0.25 * logaddexp(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 - if not x.requires_grad: - return _swoosh_l_forward_wrapper(x) - else: - return 0.25 * k2.swoosh_l(x * 4) +def swashl_and_deriv(x: Tensor): + x_offset = 4. * x - 4. + denom = 1. + x_offset.exp() + inv_denom = 1. / denom # note: 1 / infinity = 0. + deriv = 0.92 - inv_denom; + log_denom = denom.log() + log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) + y = 0.25 * log_denom - 0.08 * x - 0.00875 + return y, deriv -class SwooshLOnnx(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation.""" - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return 0.25 * logaddexp_onnx(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 +def swashr_and_deriv(x: Tensor): + x_offset = 4. * x - 1. + denom = 1. + x_offset.exp() + inv_denom = 1. / denom # note: 1 / infinity = 0. + deriv = 0.92 - inv_denom; + log_denom = denom.log() + log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) + y = 0.25 * log_denom - 0.08 * x - 0.07831542175 + return y, deriv +swashl_compiled = torch_compile(swashl) +swashr_compiled = torch_compile(swashr) +swashl_and_deriv_compiled = torch_compile(swashl_and_deriv) +swashr_and_deriv_compiled = torch_compile(swashr_and_deriv) -class SwooshR(torch.nn.Module): +class SwashL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation.""" - if torch.jit.is_scripting() or torch.jit.is_tracing(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return 0.25 * logaddexp(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 - if not x.requires_grad: - return _swoosh_r_forward_wrapper(x) - else: - return 0.25 * k2.swoosh_r(4 * x) + """Return Swash-L activation, which is the same as SwooshL but with a factor of 4 + on the input and 0.25 on the output..""" + return swashl_compiled(x) - -class SwooshROnnx(torch.nn.Module): +class SwashR(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation.""" - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return 0.25 * logaddexp_onnx(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 - - -# simple version of SwooshL that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshLForward(x: Tensor): - x_offset = 4 * x - 4.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return 0.25 * log_sum - 0.08 * x - 0.00875 - - -# simple version of SwooshR that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshRForward(x: Tensor): - x_offset = 4 * x - 1.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return 0.25 * log_sum - 0.08 * x - 0.07831542175 - + """Return Swash-R activation, which is the same as SwooshL but with a factor of 4 + on the input and 0.25 on the output..""" + return swashr_compiled(x) @@ -1519,8 +1429,8 @@ def forward( ctx.activation = activation forward_activation_dict = { - "SwooshL": _swoosh_l_forward_wrapper, - "SwooshR": _swoosh_r_forward_wrapper, + "SwashL": swashl_compiled, + "SwashR": swashr_compiled, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1538,8 +1448,8 @@ def backward(ctx, ans_grad: Tensor): (x, weight, bias, dropout_mask) = saved forward_and_deriv_activation_dict = { - "SwooshL": _swoosh_l_forward_and_deriv_wrapper, - "SwooshR": _swoosh_r_forward_and_deriv_wrapper, + "SwashL": swashl_and_deriv_compiled, + "SwashR": swashr_and_deriv_compiled, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. @@ -1569,9 +1479,9 @@ class ActivationDropoutAndLinear(torch.nn.Module): """ This merges an activation function followed by dropout and then a nn.Linear module; it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be + module. If activation == SwashL and dropout_shared_dim != None, this will be equivalent to: - nn.Sequential(SwooshL(), + nn.Sequential(SwashL(), Dropout3(dropout_p, shared_dim=dropout_shared_dim), ScaledLinear(in_channels, out_channels, bias=bias, initial_scale=initial_scale)) @@ -1583,7 +1493,7 @@ class ActivationDropoutAndLinear(torch.nn.Module): in_channels: number of input channels, e.g. 256 out_channels: number of output channels, e.g. 256 bias: if true, have a bias - activation: the activation function, for now just support SwooshL. + activation: the activation function, for now just support SwashL, SwashR. dropout_p: the dropout probability or schedule (happens after nonlinearity). dropout_shared_dim: the dimension, if any, across which the dropout mask is shared (e.g. the time dimension). If None, this may be less memory @@ -1595,7 +1505,7 @@ def __init__( in_channels: int, out_channels: int, bias: bool = True, - activation: str = "SwooshL", + activation: str = "SwashL", dropout_p: FloatLike = 0.0, dropout_shared_dim: Optional[int] = -1, initial_scale: float = 1.0, @@ -1620,10 +1530,10 @@ def __init__( def forward(self, x: Tensor): if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == "SwooshL": - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) + if self.activation == "SwashL": + x = swashl_compiled(x) + elif self.activation == "SwashR": + x = swashr_compiled(x) else: assert False, self.activation return torch.nn.functional.linear(x, self.weight, self.bias) @@ -1688,10 +1598,10 @@ def _test_double_swish_deriv(): y = m(x) -def _test_swooshl_deriv(): +def _test_swashl_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 x.requires_grad = True - m = SwooshL() + m = SwashL() tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) @@ -1702,10 +1612,10 @@ def _test_swooshl_deriv(): y = m(x) -def _test_swooshr_deriv(): +def _test_swashr_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 x.requires_grad = True - m = SwooshR() + m = SwashR() tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) @@ -1763,12 +1673,12 @@ def _test_activation_dropout_and_linear(): for bias in [True, False]: # actually we don't test for dropout_p != 0.0 because forward functions will give # different answers. This is because we are using the k2 implementation of - # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # swash_l an swash_r inside SwashL() and SwashR(), and they call randn() # internally, messing up the random state. for dropout_p in [0.0]: - for activation in ["SwooshL", "SwooshR"]: + for activation in ["SwashL", "SwashR"]: m1 = nn.Sequential( - SwooshL() if activation == "SwooshL" else SwooshR(), + SwashL() if activation == "SwashL" else SwashR(), Dropout3(p=dropout_p, shared_dim=-1), ScaledLinear( in_channels, out_channels, bias=bias, initial_scale=0.5 @@ -1808,6 +1718,9 @@ def _test_activation_dropout_and_linear(): print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) + print("grad1 = ", m1[2].weight.grad) + print("grad2 = ", m2.weight.grad) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) if bias: assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) @@ -1820,7 +1733,7 @@ def isclose(a, b): (a**2).sum() * (b**2).sum() ).sqrt() - # the SwooshL() implementation has a noisy gradient due to 1-byte + # the SwashL() implementation has a noisy gradient due to 1-byte # storage of it. assert isclose(x1.grad, x2.grad) @@ -1836,10 +1749,7 @@ def _test_orthogonal_linear(): _test_piecewise_linear() _test_softmax() _test_whiten() - _test_balancer_sign() - _test_balancer_magnitude() - _test_double_swish_deriv() - _test_swooshr_deriv() - _test_swooshl_deriv() + _test_swashr_deriv() + _test_swashl_deriv() _test_activation_dropout_and_linear() _test_orthogonal_linear() diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 5a30b1e89b..933ebd8e58 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -28,10 +28,10 @@ from scaling import ( Dropout3, ScaleGrad, - SwooshL, - SwooshLOnnx, - SwooshR, - SwooshROnnx, + SwashL, + SwashLOnnx, + SwashR, + SwashROnnx, Whiten, ) from zipformer import CompactRelPositionalEncoding @@ -82,10 +82,10 @@ def convert_scaled_to_non_scaled( for name, m in model.named_modules(): if isinstance(m, (Dropout3, ScaleGrad, Whiten)): d[name] = nn.Identity() - elif is_onnx and isinstance(m, SwooshR): - d[name] = SwooshROnnx() - elif is_onnx and isinstance(m, SwooshL): - d[name] = SwooshLOnnx() + elif is_onnx and isinstance(m, SwashR): + d[name] = SwashROnnx() + elif is_onnx and isinstance(m, SwashL): + d[name] = SwashLOnnx() elif is_onnx and isinstance(m, CompactRelPositionalEncoding): # We want to recreate the positional encoding vector when # the input changes, so we have to use torch.jit.script() diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index f47befd325..ce0617b3b4 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -31,8 +31,8 @@ ScaledConv2d, ScaleGrad, ScheduledFloat, - SwooshL, - SwooshR, + SwashL, + SwashR, Whiten, ) from torch import Tensor, nn @@ -65,7 +65,7 @@ def __init__( in_channels=channels, out_channels=hidden_channels, kernel_size=1, ) - self.activation = SwooshL() + self.activation = SwashL() self.pointwise_conv2 = nn.Conv2d( in_channels=hidden_channels, @@ -190,7 +190,7 @@ def __init__( padding=(0, 1), # (time, freq) ), ScaleGrad(0.2), - SwooshR(), + SwashR(), nn.Conv2d( in_channels=layer1_channels, out_channels=layer2_channels, @@ -198,14 +198,14 @@ def __init__( stride=2, padding=0, ), - SwooshR(), + SwashR(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, stride=(1, 2), # (time, freq) ), - SwooshR(), + SwashR(), ) diff --git a/egs/librispeech/ASR/zipformer/test_scaling.py b/egs/librispeech/ASR/zipformer/test_scaling.py index 5c04291e73..1c6e655ee4 100755 --- a/egs/librispeech/ASR/zipformer/test_scaling.py +++ b/egs/librispeech/ASR/zipformer/test_scaling.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt import torch -from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR +from scaling import PiecewiseLinear, ScheduledFloat, SwashL, SwashR def test_piecewise_linear(): @@ -52,8 +52,8 @@ def test_swoosh(): x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32) x = torch.cat([x1, x2[1:]]) - left = SwooshL()(x) - r = SwooshR()(x) + left = SwashL()(x) + r = SwashR()(x) relu = torch.nn.functional.relu(x) print(left[x == 0], r[x == 0]) @@ -63,8 +63,8 @@ def test_swoosh(): plt.axis([-10, 10, -1, 10]) # [xmin, xmax, ymin, ymax] plt.legend( [ - "SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035 ", - "SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687", + "SwashL(x) = 0.25 * log(1 + exp(4*x-4)) - 0.08x - 0.00875", + "SwashR(x) = 0.25 * log(1 + exp(4*x-1)) - 0.08x - 0.07831542175", "ReLU(x) = max(0, x)", ] ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 231bed93aa..c5256998da 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1624,7 +1624,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): self.out_proj = ActivationDropoutAndLinear( feedforward_dim, embed_dim, - activation="SwooshL", + activation="SwashL", dropout_p=dropout, dropout_shared_dim=0, bias=True, @@ -1860,7 +1860,7 @@ def __init__( self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, - activation="SwooshR", + activation="SwashR", dropout_p=0.0, initial_scale=0.05, ) From 1ee892dfcb76ecf977c11d2b687c894a0badb918 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Apr 2025 22:07:18 +0800 Subject: [PATCH 0360/1191] Bug fix --- egs/librispeech/ASR/zipformer/subsampling.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index ce0617b3b4..67ea511c4e 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -17,17 +17,15 @@ # limitations under the License. import warnings -from typing import Tuple +from typing import Tuple, Optional import torch from scaling import ( - Balancer, ScaleLimiter, ScaledLinear, ExpNorm, Dropout3, FloatLike, - Optional, ScaledConv2d, ScaleGrad, ScheduledFloat, From 6033b40f0adc3e0d0f1c9df0a0b9f2f4e41b65d8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 4 May 2025 15:28:49 +0800 Subject: [PATCH 0361/1191] Increase encoder-multiple from 3,4,8,12,12,8,4 to 4,6,9,12,12,9,6. --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3863f9af92..53f679d8af 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -227,7 +227,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-multiple", type=str, - default="3,4,8,12,12,8,4", + default="4,6,9,12,12,9,6", help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) From 4c13d3fefa03c99cf8a992ee58765bf01a468b81 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 4 May 2025 18:51:18 +0800 Subject: [PATCH 0362/1191] Make PredictLoss be a regression loss based on flow matching --- egs/librispeech/ASR/zipformer/scaling.py | 62 +++++++++++----------- egs/librispeech/ASR/zipformer/zipformer.py | 3 +- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d5e2a6145b..1b3b8ff982 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -534,39 +534,37 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, +def predict_loss(x: Tensor, predictor: nn.Module, t: float, batch_dim: int, name: str, mask: Optional[Tensor]) -> Tensor: batch_size = x.shape[batch_dim] + if batch_size % 2 != 0: assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." return torch.tensor(0.0, device=x.device) + def mean_and_variance_norm(x): + mean = x.mean(dim=list(range(x.ndim-1))) + x = x - mean + eps = 1.0e-08 + stddev = ((x ** 2).mean(dim=list(range(x.ndim-1))) + eps).sqrt() + x = x / stddev + return x + + if mask is not None: mask = mask.to(x.dtype) with torch.no_grad(): - # get the indexes. project, then mean-and-variance-norm, then - # take mx. - x_proj = torch.matmul(x, proj_weight.t()) - with torch.amp.autocast('cuda', enabled=False): - x_proj = x_proj.to(torch.float) - # Mean subtraction and variance normalization. - dims = tuple(range(0, x.ndim - 1)) - if mask is not None: - x_masked = x_proj * mask - x_proj = x_proj - x_masked.sum(dim=dims) / mask.sum(dim=dims) - x_proj = x_proj * (mask.sum(dim=dims) / ((x_masked ** 2).sum(dim=dims) + 1.0e-10)).sqrt() - else: - x_proj = x_proj - x_proj.mean(dim=dims) - x_proj = x_proj / (x_proj ** 2).mean(dim=dims).sqrt() + x_swapped = torch.roll(x, batch_size // 2, batch_dim) + x_swapped = mean_and_variance_norm(x_swapped) + rand = torch.randn_like(x_swapped) + x_interp = (t * x_swapped) + (1 - t) * rand + u_t = x_swapped - rand # reference "velocity" - indexes = torch.max(x_proj, dim=-1)[1] + v_t = predictor(torch.cat((x_interp, x), dim=-1)) # predicted "velocity" - indexes = torch.roll(indexes, batch_size // 2, batch_dim) - x_pred = predictor(x) - logprobs = x_pred.log_softmax(dim=-1) - loss = -torch.gather(logprobs, dim=-1, index=indexes.unsqueeze(-1)) + loss = ((u_t - v_t) ** 2).mean(dim=-1) if random.random() < 0.002: logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") @@ -578,29 +576,29 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, class PredictLoss(nn.Module): """ - Adds an auxiliary loss based on predicting the top-1 of 256 randomized codebook - entries. + Adds an auxiliary loss based on predicting the (noise - x) direction given two inputs: + the "x" value from the "other copy of the data", and the input ((1-t) * noise + t * x), + aas in flow matching. So a pretext task based on flow matching. + The "t" value is specified by the user, strictly between 0 and 1; smaller "t" means more noise, + larger "t" means closer to x. Smaller "t" will concentrate on the broader contours + of the distrbution. """ def __init__(self, num_channels: int, batch_dim: int = 0, - codebook_size: int = 63): + t: float = 0.2): super().__init__() - scale = num_channels ** -0.5 - self.register_buffer('proj_weight', - scale * torch.randn(codebook_size, num_channels), - persistent=True) - num_hidden = max(1024, num_channels) - self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), + num_hidden = max(1024, 2 * num_channels) + self.predictor = nn.Sequential(nn.Linear(2 * num_channels, num_hidden), nn.LeakyReLU(), - nn.Linear(num_hidden, codebook_size)) + nn.Linear(num_hidden, num_channels)) self.batch_dim = batch_dim self.name = None # will be set from training code - + self.t = t def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - return predict_loss(x, self.predictor, self.proj_weight, + return predict_loss(x, self.predictor, self.t, self.batch_dim, self.name, mask) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index c5256998da..a648a5eb3d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -802,7 +802,7 @@ def forward( bypass = self.copy_bypass(bypass) src = torch.cat((src, bypass), dim=-1) - return src, self.predict_loss(src, (src_key_padding_mask.t().unsqueeze(-1).logical_not() + return src, self.predict_loss(src, (src_key_padding_mask.t().logical_not() if src_key_padding_mask is not None else None)) def streaming_forward( @@ -1971,7 +1971,6 @@ def forward(self, x): def _test_zipformer_main(causal: bool = False): - batch_size = 5 seq_len = 20 # Just make sure the forward pass runs. From fdc57ca8aaaa5b93f27479c6d18f6b92c86b090c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 9 May 2025 12:00:39 +0800 Subject: [PATCH 0363/1191] Use four repeats of the random noise in predict_loss (not memory efficient impl.) --- egs/librispeech/ASR/zipformer/scaling.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1b3b8ff982..79269b5d65 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -586,7 +586,8 @@ class PredictLoss(nn.Module): def __init__(self, num_channels: int, batch_dim: int = 0, - t: float = 0.2): + t: float = 0.2, + num_repeats: int = 4): super().__init__() num_hidden = max(1024, 2 * num_channels) self.predictor = nn.Sequential(nn.Linear(2 * num_channels, num_hidden), @@ -595,11 +596,19 @@ def __init__(self, self.batch_dim = batch_dim self.name = None # will be set from training code self.t = t + self.num_repeats = num_repeats # to reduce variance of gradient (since this module draws random values). def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - return predict_loss(x, self.predictor, self.t, - self.batch_dim, self.name, mask) + # x is of shape (..., num_channels); mask is of shape (...), i.e. + # it matches x except is missing the last dim. + # CAUTION: the part with "repeats" actually assumes that the time dim + # is dim zero and batch dim is dim 1. + assert self.batch_dim == 1 + r = self.num_repeats + return predict_loss(x.repeat(r, 1, 1), self.predictor, self.t, + self.batch_dim, self.name, + mask.repeat(r, 1) if mask is not None else None) From e1fe22366401527b7b9fe8619c2bb14ef10067ae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 9 May 2025 12:19:30 +0800 Subject: [PATCH 0364/1191] Use two not four repeats for OOM --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 79269b5d65..cd067c8360 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -587,7 +587,7 @@ def __init__(self, num_channels: int, batch_dim: int = 0, t: float = 0.2, - num_repeats: int = 4): + num_repeats: int = 2): super().__init__() num_hidden = max(1024, 2 * num_channels) self.predictor = nn.Sequential(nn.Linear(2 * num_channels, num_hidden), From 046e7a16962c1e8ea486975cac9493473fd69535 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 9 May 2025 12:32:29 +0800 Subject: [PATCH 0365/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index cd067c8360..87bca224d4 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -608,7 +608,7 @@ def forward(self, r = self.num_repeats return predict_loss(x.repeat(r, 1, 1), self.predictor, self.t, self.batch_dim, self.name, - mask.repeat(r, 1) if mask is not None else None) + mask.repeat(r, 1) if mask is not None else None) / r From f5a59300c4ae4fc0a452ff08341195fd7e2432fb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 9 May 2025 14:04:57 +0800 Subject: [PATCH 0366/1191] Take masking-related changes from deterministic_invertible617conv --- egs/librispeech/ASR/zipformer/model.py | 21 ++++++++++++++++++--- egs/librispeech/ASR/zipformer/scaling.py | 5 ++--- egs/librispeech/ASR/zipformer/zipformer.py | 19 +++++++++++++++++-- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 28f58654e4..3fc7a5f06a 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -24,7 +24,7 @@ from torch import Tensor from encoder_interface import EncoderInterface from lhotse.dataset import SpecAugment -from scaling import ScaledLinear +from scaling import ScaledLinear, convert_num_channels from icefall.utils import add_sos, make_pad_mask, time_warp @@ -149,13 +149,20 @@ def forward_encoder( Encoder output lengths, of shape (N,). """ # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) + x, x_lens = self.encoder_embed(x, x_lens) # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - src_key_padding_mask = make_pad_mask(x_lens) + + src_key_padding_mask = make_pad_mask(x_lens) # (N, T) + specaug_mask = specaug_mask[:, ::2] + assert abs(specaug_mask.shape[1] - src_key_padding_mask.shape[1]) < 10 + specaug_mask = convert_num_channels(specaug_mask, src_key_padding_mask.shape[1]) # pad or truncate. (N, T) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) @@ -547,5 +554,13 @@ def forward_reconstruction_loss(self, # helps to down-weight the effect of very silent silences. loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, reduction='none', beta=1.0) + + # masking. if it's different from the next item on both the frequency dim + # and the time dim, it means we are in neither a time masked nor a frequency masked + # position. + mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), + log_mels != torch.roll(log_mels, 1, dims=1)) + loss = loss * mask.to(loss.dtype) + loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. return loss diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 87bca224d4..7949e33b99 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -552,9 +552,6 @@ def mean_and_variance_norm(x): return x - if mask is not None: - mask = mask.to(x.dtype) - with torch.no_grad(): x_swapped = torch.roll(x, batch_size // 2, batch_dim) x_swapped = mean_and_variance_norm(x_swapped) @@ -570,6 +567,8 @@ def mean_and_variance_norm(x): logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") if mask is not None: + mask = mask.to(x.dtype) + mask = torch.roll(mask, batch_size // 2, batch_dim) loss = loss * mask return loss.sum() # we reduce with sum in what we return. diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a648a5eb3d..5e3e7ca60d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -231,6 +231,7 @@ def forward( x: Tensor, x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, + specaug_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """ Args: @@ -279,6 +280,11 @@ def truncate(x, downsampling_factor): if src_key_padding_mask is None else src_key_padding_mask[..., ::ds] ), + specaug_mask=( + None + if specaug_mask is None + else specaug_mask[..., ::ds] + ), attn_mask=(None if attn_mask is None else attn_mask[::ds, ::ds] @@ -759,6 +765,7 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + specaug_mask: Optional[Tensor] = None, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -802,8 +809,16 @@ def forward( bypass = self.copy_bypass(bypass) src = torch.cat((src, bypass), dim=-1) - return src, self.predict_loss(src, (src_key_padding_mask.t().logical_not() - if src_key_padding_mask is not None else None)) + if src_key_padding_mask is not None and specaug_mask is not None: + mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) + elif src_key_padding_mask is not None: + mask = src_key_padding_mask.t().logical_not() + elif specaug_mask is not None: + mask = specaug_mask.t().logical_not() + else: + mask = None + + return src, self.predict_loss(src, mask) def streaming_forward( self, From 44896cc8ef95f5488ee22db3ba51fc27141b45b7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 9 May 2025 18:57:42 +0800 Subject: [PATCH 0367/1191] Revert prediction loss to the codebook type; but still use masking. --- egs/librispeech/ASR/zipformer/scaling.py | 58 ++++++++++++------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7949e33b99..65a2445b9c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -534,7 +534,7 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, predictor: nn.Module, t: float, +def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, batch_dim: int, name: str, mask: Optional[Tensor]) -> Tensor: batch_size = x.shape[batch_dim] @@ -553,21 +553,27 @@ def mean_and_variance_norm(x): with torch.no_grad(): - x_swapped = torch.roll(x, batch_size // 2, batch_dim) - x_swapped = mean_and_variance_norm(x_swapped) - rand = torch.randn_like(x_swapped) - x_interp = (t * x_swapped) + (1 - t) * rand - u_t = x_swapped - rand # reference "velocity" + # get the indexes. project, then mean-and-variance-norm, then + # take mx. + x_proj = torch.matmul(x, proj_weight.t()) + with torch.amp.autocast('cuda', enabled=False): + x_proj = mean_and_variance_norm(x_proj.to(torch.float)) + indexes = torch.max(x_proj, dim=-1)[1] - v_t = predictor(torch.cat((x_interp, x), dim=-1)) # predicted "velocity" - loss = ((u_t - v_t) ** 2).mean(dim=-1) + indexes = torch.roll(indexes, batch_size // 2, batch_dim) # predict index of the other masked copy. + x_pred = predictor(x) + logprobs = x_pred.log_softmax(dim=-1) + loss = -torch.gather(logprobs, dim=-1, index=indexes.unsqueeze(-1)) if random.random() < 0.002: logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") if mask is not None: mask = mask.to(x.dtype) + # we also swap the mask over the two copies of the data; the mask goes with the thing that + # is predicted, not the thing we predict it from.. the idea being that we don't want to ask + # the model to predict masked portions of the time sequence. mask = torch.roll(mask, batch_size // 2, batch_dim) loss = loss * mask @@ -575,39 +581,33 @@ def mean_and_variance_norm(x): class PredictLoss(nn.Module): """ - Adds an auxiliary loss based on predicting the (noise - x) direction given two inputs: - the "x" value from the "other copy of the data", and the input ((1-t) * noise + t * x), - aas in flow matching. So a pretext task based on flow matching. - The "t" value is specified by the user, strictly between 0 and 1; smaller "t" means more noise, - larger "t" means closer to x. Smaller "t" will concentrate on the broader contours - of the distrbution. + Adds an auxiliary loss based on predicting the top-1 of randomized codebook + entries. (This relies on the CR-CTC structure of having two differently-masked + copies of the same utterance). Mean and variance normalization is applied prior to getting + the codebook indexes to keep this stable. """ def __init__(self, num_channels: int, batch_dim: int = 0, - t: float = 0.2, - num_repeats: int = 2): + codebook_size: int = 64): super().__init__() - num_hidden = max(1024, 2 * num_channels) - self.predictor = nn.Sequential(nn.Linear(2 * num_channels, num_hidden), - nn.LeakyReLU(), - nn.Linear(num_hidden, num_channels)) + scale = num_channels ** -0.5 + self.register_buffer('proj_weight', + scale * torch.randn(codebook_size, num_channels), + persistent=True) + num_hidden = max(1024, num_channels) + self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), + SwashR(), + nn.Linear(num_hidden, codebook_size)) self.batch_dim = batch_dim self.name = None # will be set from training code - self.t = t - self.num_repeats = num_repeats # to reduce variance of gradient (since this module draws random values). def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: # x is of shape (..., num_channels); mask is of shape (...), i.e. # it matches x except is missing the last dim. - # CAUTION: the part with "repeats" actually assumes that the time dim - # is dim zero and batch dim is dim 1. - assert self.batch_dim == 1 - r = self.num_repeats - return predict_loss(x.repeat(r, 1, 1), self.predictor, self.t, - self.batch_dim, self.name, - mask.repeat(r, 1) if mask is not None else None) / r + return predict_loss(x, self.predictor, self.proj_weight, + self.batch_dim, self.name, mask) From 4f32e816b4c31c8ba1fcb78de13e652fba36f903 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 9 May 2025 21:56:36 +0800 Subject: [PATCH 0368/1191] Bug fix re mask. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 65a2445b9c..82492d6f68 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -575,7 +575,7 @@ def mean_and_variance_norm(x): # is predicted, not the thing we predict it from.. the idea being that we don't want to ask # the model to predict masked portions of the time sequence. mask = torch.roll(mask, batch_size // 2, batch_dim) - loss = loss * mask + loss = loss * mask.unsqueeze(-1) return loss.sum() # we reduce with sum in what we return. From 2f6de60417e3e4fe33ed684ef3714da2323af96b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 12 May 2025 22:45:59 +0800 Subject: [PATCH 0369/1191] halved in_proj scale of self_attn, due to divergence. now 0.125. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 5e3e7ca60d..8111379b3d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1190,7 +1190,7 @@ def __init__( # it would be necessary to apply the scaling factor in the forward function. self.in_proj = ScaledLinear( embed_dim, in_proj_dim, - bias=True, initial_scale=0.25 * query_head_dim**-0.25 + bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) self.whiten_keys = Whiten( From 8441285f86c551e40818fd5030d6a248776c188d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 18 Jun 2025 11:17:07 +0800 Subject: [PATCH 0370/1191] Make diagnostics dump-able --- egs/librispeech/ASR/zipformer/train.py | 5 +- icefall/diagnostics.py | 231 ++----------------------- 2 files changed, 21 insertions(+), 215 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 53f679d8af..ef5bb0cb6d 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1574,7 +1574,10 @@ def remove_short_and_long_utt(c: Cut): ) if params.print_diagnostics: - diagnostic.print_diagnostics() + d = diagnostic.print_diagnostics() + filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" + torch.save(d, filename) + logging.info(f"Saved detailed diagnostics to {filename}") break save_checkpoint( diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 2cd350c07d..7a7b232ca5 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -215,12 +215,17 @@ def accumulate(self, x, class_name: Optional[str] = None): else: this_dim_stats[stats_type].append(TensorAndCount(stats, count)) - def print_diagnostics(self): - """Print diagnostics for each dimension of the tensor.""" + def print_diagnostics(self) -> dict: + """Print diagnostics for each dimension of the tensor. Returns a dict containing more specific stats, as tensors, that can be used for further + analysis if needed""" if self.stats is None: print(f"Warning: the stats of {self.name} is None.") return + + ans_dict = dict() + for dim, this_dim_stats in enumerate(self.stats): + ans_dict[dim] = dict() if "rms" in this_dim_stats and "value" in this_dim_stats: # produce "stddev" stats, which is centered RMS. rms_stats_list = this_dim_stats["rms"] @@ -279,6 +284,8 @@ def get_count(count): # we stored the square; after aggregation we need to take sqrt. stats = stats.sqrt() + ans_dict[dim][stats_type] = stats + # if `summarize` we print percentiles of the stats; else, # we print out individual elements. summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( @@ -307,6 +314,7 @@ def get_count(count): # can be attributed to the mean of the distribution. norm = (stats**2).sum().sqrt().item() ans += f", norm={norm:.2g}" + mean = stats.mean().item() rms = (stats**2).mean().sqrt().item() ans += f", mean={mean:.3g}, rms={rms:.3g}" @@ -324,184 +332,9 @@ def get_count(count): print( f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) + return ans_dict -class ScalarDiagnostic(object): - """This class is not directly used by the user, it is responsible for - collecting diagnostics for a single module (subclass of torch.nn.Module) that - represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc. - """ - - def __init__(self, opts: TensorDiagnosticOptions, name: str): - self.opts = opts - self.name = name - self.class_name = None # will assign in accumulate() - self.is_forward_pass = True - - self.tick_scale = None - - self.saved_inputs = [] - self.is_ok = True - - self.counts = None - self.sum_grad = None - self.sum_gradsq = None - self.sum_abs_grad = None - - def accumulate_input(self, x: Tensor, class_name: Optional[str] = None): - """ - Called in forward pass. - """ - if not self.is_forward_pass: - # in case we did a forward pass without a backward pass, for some reason. - self.saved_inputs = [] - self.is_forward_pass = True - - if class_name is not None: - self.class_name = class_name - if not self.is_ok: - return - - limit = 10 - if len(self.saved_inputs) > limit: - print( - f"ERROR: forward pass called for this module over {limit} times with no backward pass. " - f" Will not accumulate scalar stats." - ) - self.is_ok = False - return - self.saved_inputs.append(x) - - def accumulate_output_grad(self, grad: Tensor): - if not self.is_ok: - return - if self.is_forward_pass: - self.is_forward_pass = False - - last_shape = ( - "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape - ) - if len(self.saved_inputs) == 0 or grad.shape != last_shape: - print( - f"ERROR: shape mismatch or no forward activation present when backward " - f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" - f", shape-of-last-saved-input={last_shape}" - ) - self.is_ok = False - return - - x = self.saved_inputs.pop() - self.process_input_and_grad(x, grad) - - def process_input_and_grad(self, x: Tensor, grad: Tensor): - assert x.shape == grad.shape - x = x.flatten() - grad = grad.flatten() - - num_ticks_per_side = 256 - - if self.tick_scale is None: - x_abs_sorted = x.abs().sort()[0] - # take the 98th percentile as the largest value we count separately. - index = int(x.numel() * 0.98) - self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side) - - # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1] - self.counts = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.long, device=x.device - ) - self.sum_grad = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.double, device=x.device - ) - # sum_gradsq is for getting error bars. - self.sum_gradsq = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.double, device=x.device - ) - self.sum_abs_grad = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.double, device=x.device - ) - - # this will round down. - x = (x / self.tick_scale).to(torch.long) - x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1) - x = x + num_ticks_per_side - - self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) - self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double)) - self.sum_gradsq.index_add_( - dim=0, index=x, source=(grad * grad).to(torch.double) - ) - self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double)) - - def print_diagnostics(self): - """Print diagnostics.""" - if self.is_ok is False or self.counts is None: - print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}") - return - - counts = self.counts.to("cpu") - sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32) - sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32) - sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32) - - counts_cumsum = counts.cumsum(dim=0) - counts_tot = counts_cumsum[-1] - - # subdivide the distribution up into `num_bins` intervals for analysis, for greater - # statistical significance. each bin corresponds to multiple of the original 'tick' intervals. - num_bins = 20 - - # integer division - counts_per_bin = (counts_tot // num_bins) + 1 - bin_indexes = counts_cumsum // counts_per_bin - bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long) - - bin_counts = torch.zeros(num_bins, dtype=torch.long) - bin_counts.index_add_(dim=0, index=bin_indexes, source=counts) - bin_grad = torch.zeros(num_bins) - bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad) - bin_gradsq = torch.zeros(num_bins) - bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq) - bin_abs_grad = torch.zeros(num_bins) - bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad) - - avg_grad = bin_grad / bin_counts - avg_grad_stddev = (bin_gradsq / bin_counts).sqrt() - - bin_boundary_counts = ( - torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin - ) - bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts) - # boundaries are the "x" values between the bins, e.g. corresponding to the - # locations of percentiles of the distribution. - num_ticks_per_side = counts.numel() // 2 - bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale - - bin_grad = bin_grad / (bin_counts + 1) - bin_conf_interval = bin_gradsq.sqrt() / ( - bin_counts + 1 - ) # consider this a standard deviation. - # bin_grad / bin_abs_grad will give us a sense for how important in a practical sense, - # the gradients are. - bin_abs_grad = bin_abs_grad / (bin_counts + 1) - - bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20) - bin_conf = bin_grad / (bin_conf_interval + 1.0e-20) - - def tensor_to_str(x: Tensor): - x = ["%.2g" % f for f in x] - x = "[" + " ".join(x) + "]" - return x - - maybe_class_name = ( - f" type={self.class_name}," if self.class_name is not None else "" - ) - - print( - f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, " - f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}" - ) - class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -526,10 +359,13 @@ def __getitem__(self, name: str): self.diagnostics[name] = T(self.opts, name) return self.diagnostics[name] - def print_diagnostics(self): - """Print diagnostics for each tensor.""" + def print_diagnostics(self) -> dict: + """Print diagnostics for each tensor. Returns dict with more detailed per-dimension info + that could be further analyzed.""" + ans = dict() for k in sorted(self.diagnostics.keys()): - self.diagnostics[k].print_diagnostics() + ans[k] = self.diagnostics[k].print_diagnostics() + return ans def get_class_name(module: nn.Module): @@ -626,39 +462,6 @@ def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) - if type(module).__name__ in [ - "Sigmoid", - "Tanh", - "ReLU", - "TanSwish", - "Swish", - "DoubleSwish", - "Swoosh", - ]: - # For these specific module types, accumulate some additional diagnostics - # that can help us improve the activation function. These require a lot of memory, - # to save the forward activations, so limit this to some select classes. - # Note: this will not work correctly for all model types. - def scalar_forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): - if isinstance(_input, tuple): - (_input,) = _input - assert isinstance(_input, Tensor) - _model_diagnostic[f"{_name}.scalar"].accumulate_input( - _input, class_name=get_class_name(_module) - ) - - def scalar_backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): - if isinstance(_output, tuple): - (_output,) = _output - assert isinstance(_output, Tensor) - _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) - - module.register_forward_hook(scalar_forward_hook) - module.register_backward_hook(scalar_backward_hook) for name, parameter in model.named_parameters(): From 9651926e54989f523af1f57df125fb597e8b97a2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 15:48:51 +0800 Subject: [PATCH 0371/1191] Change SpecAugment to ExpAugment, cleaner and simpler implementation. --- egs/librispeech/ASR/zipformer/exp_augment.py | 203 +++++++++++++++++++ egs/librispeech/ASR/zipformer/model.py | 3 +- egs/librispeech/ASR/zipformer/train.py | 30 +-- 3 files changed, 207 insertions(+), 29 deletions(-) create mode 100644 egs/librispeech/ASR/zipformer/exp_augment.py diff --git a/egs/librispeech/ASR/zipformer/exp_augment.py b/egs/librispeech/ASR/zipformer/exp_augment.py new file mode 100644 index 0000000000..9d4b66a676 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/exp_augment.py @@ -0,0 +1,203 @@ +import bisect +import math +import random +from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union + +import numpy as np +import torch +from torch import Tensor + +from lhotse import CutSet, FeatureExtractor +from lhotse.augmentation import dereverb_wpe_torch +from lhotse.utils import Pathlike + + + +class ExpAugment(torch.nn.Module): + """ + ExpAugment is a different, simpler implementation of the feature-masking and frame-masking + aspects of SpecAugment, without the time warping for now. + """ + def __init__( + self, + feature_mask_fraction: float = 0.16, # mean fraction masked, not max. + num_feature_masks: int = 2, + frame_mask_fraction: float = 0.23, # the mean, not max. + frame_mask_size: float = 50.0, # interpret as mean size of masked region, in frames. + p=0.9, # probability of doing augmentation, and if we do augmentation, of doing each type of augmentation + ): + super().__init__() + assert 0 <= p <= 1 + assert 0 <= feature_mask_fraction <= 1 + assert 0 <= frame_mask_fraction <= 1 + assert 0 < frame_mask_size + + self.feature_mask_fraction = feature_mask_fraction + self.num_feature_masks = num_feature_masks + self.frame_mask_fraction = frame_mask_fraction + self.frame_mask_size = frame_mask_size + self.p = p + + def forward( + self, + features: torch.Tensor, + lengths: Tensor, # can just set this to [ seq_len ] * batch_size + ) -> torch.Tensor: + """ + Computes ExpAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param lengths: an int tensor of shape ``(B,)``, giving the number + of frames 0 < f <= T for each sequence. Only used for masking when + computing the feature means. + :return: an augmented tensor of shape ``(B, T, F)``. + """ + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of " "single-channel feature matrices." + ) + B, T, F = features.shape + features = features.clone() + + if random.random() >= self.p: + return features + + # get feature means. + kwargs = {'device': features.device} + length_mask = (torch.arange(T, **kwargs)[:, None] < lengths[:, None, None]).to(features.dtype) + # length_mask: (B, T, 1), 1.0 for "kept" frames. + means = (features * length_mask).mean(dim=(1, 2), keepdim=True) + means = means * (B / lengths[:, None, None]) # compensate means for lengths less than B. + # means: (B, 1, 1) + + + features_unmasked = features + + features = self._mask_on_axis(features, means, axis=2, + num_regions=round(self.num_feature_masks/self.feature_mask_fraction), + num_masked_regions=self.num_feature_masks) + + + num_regions=max(3, round(T / self.frame_mask_size)) # at least 3 regions + num_masked_regions=max(1, round(num_regions * self.frame_mask_fraction)) + + features = self._mask_on_axis(features, means, axis=1, + num_regions=num_regions, + num_masked_regions=num_masked_regions) + + features = torch.where(torch.rand(B, 1, 1, **kwargs).expand_as(features) < self.p, + features, features_unmasked) + + return features + + def _mask_on_axis(self, + features: torch.Tensor, + means: torch.Tensor, + axis: int, + num_regions: int, + num_masked_regions: int): + """ + Mask ``features`` on a particular axis by replacing masked segments of that sequence with + ``means``. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param means: a batch of means of feature matrices with shape ``(B, 1, 1)`` + :param axis: the axis to mask on, i.e. 1 for time, 2 for frequency/feature. + :param num_regions: the number of regions to divide up the sequence-length, i.e. T or F, + on this axis + :param num_masked_regions: the number of those regions to mask. + """ + assert axis in [1,2] + # num_regions refers to regions including 'exterior' regions + num_regions = max(num_regions, (2 * num_masked_regions) + 1) + device = features.device + shape = list(features.shape) + B = shape[0] + N = shape[axis] # T or F + + # subtract num_regions; we'll later add torch.arange(num_regions + 1) to the rounded and sorted + # boundary edges to ensure all interior region boundaries are distinct and do not include 0 or N. + # + N_reduced = N - num_regions + + # 'boundaries' are the interior boundaries, i.e. the region edges except the beginning and + # end respectively of the first and last region. + boundaries = N_reduced * torch.rand(B, num_regions - 1, device=device) + boundaries = boundaries.round().to(torch.long) + boundaries = boundaries.sort(dim=1)[0] # make them consecutive. + # make sure the boundaries are all distinct from each other and + # also from N. + boundaries = boundaries + torch.arange(1, num_regions, device=device) + + # won't keep first or last region. and the numbering is in a numbering + # where the 1st region (bounded by start of sequence) is not counted, + # so the random numbers from the sort() will be between 0, 1, ... num_regions - 3. + kept_regions = torch.rand(B, num_regions - 2, device=device).sort(dim=1)[1] + region_numbers = kept_regions[:, :(2*num_masked_regions - 1)].sort(dim=1)[0] + + # example: + #torch.rand(3, 10).sort(dim=1)[1][:, :5].sort(dim=1)[0] + #tensor([[0, 1, 2, 5, 7], + # [1, 3, 6, 7, 8], + # [0, 1, 5, 8, 9]]) + + # of the not-discarded regions, take alternate regions. + region_numbers = region_numbers[:, ::2] + region_starts = torch.gather(boundaries, index=region_numbers, dim=1) + region_ends = torch.gather(boundaries[:, 1:], index=region_numbers, dim=1) + assert region_ends.shape == (B, num_masked_regions) + + + markers = torch.zeros(B, N, device=device, dtype=torch.long) + ones = torch.ones(*region_starts.shape, device=device, dtype=torch.long) + markers.scatter_(index=region_starts, dim=1, src=ones) + markers.scatter_(index=region_ends, dim=1, src=ones) + + cumsum = torch.cumsum(markers, dim=1) + is_masked = ((cumsum % 2) == 1) # (B, N), is True at spots to mask. + if axis == 1: + is_masked = is_masked.unsqueeze(-1) + else: + is_masked = is_masked.unsqueeze(1) + + return torch.where(is_masked.expand_as(features), means.expand_as(features), features) + + + def state_dict(self, **kwargs) -> Dict[str, Any]: + return dict( + feature_mask_fraction=self.feature_mask_fraction, + num_feature_masks=self.num_feature_masks, + frame_mask_fraction=self.frame_mask_fraction, + frame_masks_size=self.frame_mask_size, + p=self.p) + + + def load_state_dict(self, state_dict: Dict[str, Any]): + for name in ["feature_mask_fraction", "num_feature_masks", + "frame_mask_fraction", "frame_mask_size", "p"]: + if name in state_dict: + setattr(self, name, state_dict["name"]) + + + +def _test_exp_augment(): + exp_augment = ExpAugment(p=1.0, frame_mask_size=10) + B, T, F = 15, 100, 20 + #device = 'cuda' + device = 'cpu' + features = torch.randn(B, T, F, device=device) + lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) + #print("features=", features) + features = exp_augment(features, lengths) + + frame_is_masked = features[:, :, 0] == features[:, :, -1] + print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) + feature_is_masked = features[:, 0] == features[:, -1] + print("mean feature_is_masked = ", feature_is_masked.to(torch.float).mean()) + + +if __name__ == '__main__': + _test_exp_augment() diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 3fc7a5f06a..17784e9d68 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -23,7 +23,6 @@ import torch.nn as nn from torch import Tensor from encoder_interface import EncoderInterface -from lhotse.dataset import SpecAugment from scaling import ScaledLinear, convert_num_channels from icefall.utils import add_sos, make_pad_mask, time_warp @@ -369,7 +368,7 @@ def forward( lm_scale: float = 0.0, use_cr_ctc: bool = False, use_spec_aug: bool = False, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, supervision_segments: Optional[torch.Tensor] = None, time_warp_factor: Optional[int] = 80, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index ef5bb0cb6d..da4a8cc11d 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -71,7 +71,8 @@ from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut -from lhotse.dataset import SpecAugment +# from lhotse.dataset import SpecAugment +from exp_augment import ExpAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -549,13 +550,6 @@ def get_parser(): help="Prediction of random k-means after widest zipformer layer" ) - parser.add_argument( - "--time-mask-ratio", - type=float, - default=2.5, - help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", - ) - parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -827,24 +821,6 @@ def get_model(params: AttributeDict) -> nn.Module: return model -def get_spec_augment(params: AttributeDict) -> SpecAugment: - num_frame_masks = int(10 * params.time_mask_ratio) - max_frames_mask_fraction = 0.15 * params.time_mask_ratio - logging.info( - f"num_frame_masks: {num_frame_masks}, " - f"max_frames_mask_fraction: {max_frames_mask_fraction}" - ) - spec_augment = SpecAugment( - time_warp_factor=0, # Do time warping in model.py - num_frame_masks=num_frame_masks, # default: 10 - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 - ) - return spec_augment - - def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -1413,7 +1389,7 @@ def run(rank, world_size, args): if params.use_cr_ctc: assert params.use_ctc assert not params.enable_spec_aug # we will do spec_augment in model.py - spec_augment = get_spec_augment(params) + spec_augment = ExpAugment() else: spec_augment = None From d46a377719ca10210f9cd9d8b1c258571d3d70ca Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 15:54:22 +0800 Subject: [PATCH 0372/1191] Bug fix re typing --- egs/librispeech/ASR/zipformer/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index da4a8cc11d..0b6a4339ef 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -943,7 +943,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -1104,7 +1104,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1132,7 +1132,7 @@ def train_one_epoch( scaler: The scaler used for mix precision training. spec_augment: - The SpecAugment instance used only when use_cr_ctc is True. + The SpecAugment (or similar) instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. tb_writer: @@ -1612,7 +1612,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, ): from lhotse.dataset import find_pessimistic_batches From 7c2bb5f5d02befd88a4acdcb8a21f1c78965c308 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 16:08:23 +0800 Subject: [PATCH 0373/1191] Couple bug fixes with spec augment --- egs/librispeech/ASR/zipformer/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 17784e9d68..d995bdaa8d 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -429,7 +429,7 @@ def forward( if use_cr_ctc: assert self.use_ctc if use_spec_aug: - assert spec_augment is not None and spec_augment.time_warp_factor < 1 + assert spec_augment is not None # Apply time warping before input duplicating assert supervision_segments is not None x = time_warp( @@ -438,7 +438,7 @@ def forward( supervision_segments=supervision_segments, ) # Independently apply frequency masking and time masking to the two copies - x = spec_augment(x.repeat(2, 1, 1)) + x = spec_augment(x.repeat(2, 1, 1), x_lens.to(x.device)) else: x = x.repeat(2, 1, 1) x_lens = x_lens.repeat(2) From c390eeec0d9614e24daeee7412f6eec381104520 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 16:41:45 +0800 Subject: [PATCH 0374/1191] Bug fix --- egs/librispeech/ASR/zipformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index d995bdaa8d..9d54986f03 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -438,7 +438,7 @@ def forward( supervision_segments=supervision_segments, ) # Independently apply frequency masking and time masking to the two copies - x = spec_augment(x.repeat(2, 1, 1), x_lens.to(x.device)) + x = spec_augment(x.repeat(2, 1, 1), x_lens.repeat(2).to(x.device)) else: x = x.repeat(2, 1, 1) x_lens = x_lens.repeat(2) From a8e4056952352ec036b2c65dd404292b9771c6be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 17:22:08 +0800 Subject: [PATCH 0375/1191] Move exp_augment.py, and bug fix regarding p. --- egs/librispeech/ASR/zipformer/train.py | 2 +- icefall/__init__.py | 2 +- .../ASR/zipformer => icefall}/exp_augment.py | 10 ---------- 3 files changed, 2 insertions(+), 12 deletions(-) rename {egs/librispeech/ASR/zipformer => icefall}/exp_augment.py (96%) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 0b6a4339ef..367e06da83 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -72,7 +72,6 @@ from joiner import Joiner from lhotse.cut import Cut # from lhotse.dataset import SpecAugment -from exp_augment import ExpAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -95,6 +94,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error +from icefall.exp_augment import ExpAugment from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, diff --git a/icefall/__init__.py b/icefall/__init__.py index b1e4313e9b..831d66f0a1 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,6 +1,6 @@ # isort:skip_file -from . import checkpoint, decode, dist, env, utils +from . import checkpoint, decode, dist, env, utils, exp_augment from .byte_utils import ( byte_decode, diff --git a/egs/librispeech/ASR/zipformer/exp_augment.py b/icefall/exp_augment.py similarity index 96% rename from egs/librispeech/ASR/zipformer/exp_augment.py rename to icefall/exp_augment.py index 9d4b66a676..9f9a98c404 100644 --- a/egs/librispeech/ASR/zipformer/exp_augment.py +++ b/icefall/exp_augment.py @@ -1,15 +1,7 @@ -import bisect -import math import random from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union -import numpy as np import torch -from torch import Tensor - -from lhotse import CutSet, FeatureExtractor -from lhotse.augmentation import dereverb_wpe_torch -from lhotse.utils import Pathlike @@ -62,8 +54,6 @@ def forward( B, T, F = features.shape features = features.clone() - if random.random() >= self.p: - return features # get feature means. kwargs = {'device': features.device} From c4ec03be191c1df956d8519277c787b5827541b4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 17:25:20 +0800 Subject: [PATCH 0376/1191] Bug fix re Tensor type --- icefall/exp_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 9f9a98c404..67d98a1004 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -33,7 +33,7 @@ def __init__( def forward( self, features: torch.Tensor, - lengths: Tensor, # can just set this to [ seq_len ] * batch_size + lengths: torch.Tensor, # can just set this to [ seq_len ] * batch_size ) -> torch.Tensor: """ Computes ExpAugment for a batch of feature matrices. From b91bfedfd80e528030bc5cc8941d870aa242ecdf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 19:03:19 +0800 Subject: [PATCH 0377/1191] fix state_dict issue --- icefall/exp_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 67d98a1004..c4bfed9984 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -161,7 +161,7 @@ def state_dict(self, **kwargs) -> Dict[str, Any]: feature_mask_fraction=self.feature_mask_fraction, num_feature_masks=self.num_feature_masks, frame_mask_fraction=self.frame_mask_fraction, - frame_masks_size=self.frame_mask_size, + frame_mask_size=self.frame_mask_size, p=self.p) From 8bb257d203699d8fe62f2872ee6ce2a7188d1b15 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 19:07:56 +0800 Subject: [PATCH 0378/1191] Get rid of lengths in exp_augment, use overall mean. --- egs/librispeech/ASR/zipformer/model.py | 2 +- icefall/exp_augment.py | 27 ++++++++++---------------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 9d54986f03..db4230a1a6 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -438,7 +438,7 @@ def forward( supervision_segments=supervision_segments, ) # Independently apply frequency masking and time masking to the two copies - x = spec_augment(x.repeat(2, 1, 1), x_lens.repeat(2).to(x.device)) + x = spec_augment(x.repeat(2, 1, 1)) else: x = x.repeat(2, 1, 1) x_lens = x_lens.repeat(2) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index c4bfed9984..15cc8745a1 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -33,7 +33,6 @@ def __init__( def forward( self, features: torch.Tensor, - lengths: torch.Tensor, # can just set this to [ seq_len ] * batch_size ) -> torch.Tensor: """ Computes ExpAugment for a batch of feature matrices. @@ -43,9 +42,7 @@ def forward( only to selected areas of the input. The format of this input is described below. :param features: a batch of feature matrices with shape ``(B, T, F)``. - :param lengths: an int tensor of shape ``(B,)``, giving the number - of frames 0 < f <= T for each sequence. Only used for masking when - computing the feature means. + :return: an augmented tensor of shape ``(B, T, F)``. """ assert len(features.shape) == 3, ( @@ -57,16 +54,13 @@ def forward( # get feature means. kwargs = {'device': features.device} - length_mask = (torch.arange(T, **kwargs)[:, None] < lengths[:, None, None]).to(features.dtype) - # length_mask: (B, T, 1), 1.0 for "kept" frames. - means = (features * length_mask).mean(dim=(1, 2), keepdim=True) - means = means * (B / lengths[:, None, None]) # compensate means for lengths less than B. - # means: (B, 1, 1) + + mean = features.mean() features_unmasked = features - features = self._mask_on_axis(features, means, axis=2, + features = self._mask_on_axis(features, mean, axis=2, num_regions=round(self.num_feature_masks/self.feature_mask_fraction), num_masked_regions=self.num_feature_masks) @@ -74,7 +68,7 @@ def forward( num_regions=max(3, round(T / self.frame_mask_size)) # at least 3 regions num_masked_regions=max(1, round(num_regions * self.frame_mask_fraction)) - features = self._mask_on_axis(features, means, axis=1, + features = self._mask_on_axis(features, mean, axis=1, num_regions=num_regions, num_masked_regions=num_masked_regions) @@ -85,16 +79,16 @@ def forward( def _mask_on_axis(self, features: torch.Tensor, - means: torch.Tensor, + mean: torch.Tensor, axis: int, num_regions: int, num_masked_regions: int): """ Mask ``features`` on a particular axis by replacing masked segments of that sequence with - ``means``. + ``mean``. :param features: a batch of feature matrices with shape ``(B, T, F)``. - :param means: a batch of means of feature matrices with shape ``(B, 1, 1)`` + :param mean: a batch of means of feature matrices with shape ``(B, 1, 1)`` :param axis: the axis to mask on, i.e. 1 for time, 2 for frequency/feature. :param num_regions: the number of regions to divide up the sequence-length, i.e. T or F, on this axis @@ -153,7 +147,7 @@ def _mask_on_axis(self, else: is_masked = is_masked.unsqueeze(1) - return torch.where(is_masked.expand_as(features), means.expand_as(features), features) + return torch.where(is_masked.expand_as(features), mean[None, None, None].expand_as(features), features) def state_dict(self, **kwargs) -> Dict[str, Any]: @@ -179,9 +173,8 @@ def _test_exp_augment(): #device = 'cuda' device = 'cpu' features = torch.randn(B, T, F, device=device) - lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) #print("features=", features) - features = exp_augment(features, lengths) + features = exp_augment(features) frame_is_masked = features[:, :, 0] == features[:, :, -1] print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) From 155187c63d4a3ee80b81acd0490e6d5887313b7a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 22:37:20 +0800 Subject: [PATCH 0379/1191] Add comparison with SpecAugment; and do simplification/refactor in exp_augment.py --- icefall/exp_augment.py | 137 ++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 64 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index c4bfed9984..70aa8a0206 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -12,9 +12,9 @@ class ExpAugment(torch.nn.Module): """ def __init__( self, - feature_mask_fraction: float = 0.16, # mean fraction masked, not max. + feature_mask_fraction: float = 0.26, # mean fraction masked, not max. num_feature_masks: int = 2, - frame_mask_fraction: float = 0.23, # the mean, not max. + frame_mask_fraction: float = 0.21, # the mean, not max. frame_mask_size: float = 50.0, # interpret as mean size of masked region, in frames. p=0.9, # probability of doing augmentation, and if we do augmentation, of doing each type of augmentation ): @@ -63,20 +63,17 @@ def forward( means = means * (B / lengths[:, None, None]) # compensate means for lengths less than B. # means: (B, 1, 1) - features_unmasked = features features = self._mask_on_axis(features, means, axis=2, - num_regions=round(self.num_feature_masks/self.feature_mask_fraction), - num_masked_regions=self.num_feature_masks) - + masked_fraction=self.feature_mask_fraction, + num_masks=self.num_feature_masks) - num_regions=max(3, round(T / self.frame_mask_size)) # at least 3 regions - num_masked_regions=max(1, round(num_regions * self.frame_mask_fraction)) + num_masks = max(1, round((T * self.frame_mask_fraction) / self.frame_mask_size)) features = self._mask_on_axis(features, means, axis=1, - num_regions=num_regions, - num_masked_regions=num_masked_regions) + masked_fraction=self.frame_mask_fraction, + num_masks=num_masks) features = torch.where(torch.rand(B, 1, 1, **kwargs).expand_as(features) < self.p, features, features_unmasked) @@ -87,8 +84,8 @@ def _mask_on_axis(self, features: torch.Tensor, means: torch.Tensor, axis: int, - num_regions: int, - num_masked_regions: int): + masked_fraction: float, + num_masks: int) -> torch.Tensor: """ Mask ``features`` on a particular axis by replacing masked segments of that sequence with ``means``. @@ -96,56 +93,38 @@ def _mask_on_axis(self, :param features: a batch of feature matrices with shape ``(B, T, F)``. :param means: a batch of means of feature matrices with shape ``(B, 1, 1)`` :param axis: the axis to mask on, i.e. 1 for time, 2 for frequency/feature. - :param num_regions: the number of regions to divide up the sequence-length, i.e. T or F, - on this axis - :param num_masked_regions: the number of those regions to mask. + :param masked_fraction: the fraction of the data to mask, in expectation. + :param num_masks: the number of masked regions. """ assert axis in [1,2] # num_regions refers to regions including 'exterior' regions - num_regions = max(num_regions, (2 * num_masked_regions) + 1) device = features.device shape = list(features.shape) B = shape[0] N = shape[axis] # T or F - # subtract num_regions; we'll later add torch.arange(num_regions + 1) to the rounded and sorted - # boundary edges to ensure all interior region boundaries are distinct and do not include 0 or N. - # - N_reduced = N - num_regions - - # 'boundaries' are the interior boundaries, i.e. the region edges except the beginning and - # end respectively of the first and last region. - boundaries = N_reduced * torch.rand(B, num_regions - 1, device=device) - boundaries = boundaries.round().to(torch.long) - boundaries = boundaries.sort(dim=1)[0] # make them consecutive. - # make sure the boundaries are all distinct from each other and - # also from N. - boundaries = boundaries + torch.arange(1, num_regions, device=device) - - # won't keep first or last region. and the numbering is in a numbering - # where the 1st region (bounded by start of sequence) is not counted, - # so the random numbers from the sort() will be between 0, 1, ... num_regions - 3. - kept_regions = torch.rand(B, num_regions - 2, device=device).sort(dim=1)[1] - region_numbers = kept_regions[:, :(2*num_masked_regions - 1)].sort(dim=1)[0] - - # example: - #torch.rand(3, 10).sort(dim=1)[1][:, :5].sort(dim=1)[0] - #tensor([[0, 1, 2, 5, 7], - # [1, 3, 6, 7, 8], - # [0, 1, 5, 8, 9]]) - - # of the not-discarded regions, take alternate regions. - region_numbers = region_numbers[:, ::2] - region_starts = torch.gather(boundaries, index=region_numbers, dim=1) - region_ends = torch.gather(boundaries[:, 1:], index=region_numbers, dim=1) - assert region_ends.shape == (B, num_masked_regions) + def sample_from_exponential(*shape): + eps=1.0e-20 + return -(torch.rand(*shape, device=device) + eps).log() - markers = torch.zeros(B, N, device=device, dtype=torch.long) - ones = torch.ones(*region_starts.shape, device=device, dtype=torch.long) - markers.scatter_(index=region_starts, dim=1, src=ones) - markers.scatter_(index=region_ends, dim=1, src=ones) + mask_lengths = sample_from_exponential(B, num_masks) * masked_fraction + unmasked_lengths = sample_from_exponential(B, num_masks + 1) * ((1. - masked_fraction) * num_masks / (num_masks + 1)) + + lengths = torch.empty(B, 2 * num_masks + 1, device=device) + lengths[:, 1::2] = mask_lengths + lengths[:, 0::2] = unmasked_lengths + for _ in range(2): + lengths = lengths * (N / lengths.sum(1, keepdim=True)) + lengths = lengths.round().clamp_(min=1).to(torch.long) + + positions = lengths.cumsum(dim=1) + positions = positions[:-1].clamp_(max=N-1) # don't need the last position, which should be close to N. + + ones = torch.ones(*positions.shape, device=device, dtype=torch.long) + markers = torch.zeros(B, N, device=device, dtype=torch.long) + markers.scatter_(index=positions, dim=1, src=ones) cumsum = torch.cumsum(markers, dim=1) is_masked = ((cumsum % 2) == 1) # (B, N), is True at spots to mask. if axis == 1: @@ -174,20 +153,50 @@ def load_state_dict(self, state_dict: Dict[str, Any]): def _test_exp_augment(): - exp_augment = ExpAugment(p=1.0, frame_mask_size=10) - B, T, F = 15, 100, 20 - #device = 'cuda' - device = 'cpu' - features = torch.randn(B, T, F, device=device) - lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) - #print("features=", features) - features = exp_augment(features, lengths) - - frame_is_masked = features[:, :, 0] == features[:, :, -1] - print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) - feature_is_masked = features[:, 0] == features[:, -1] - print("mean feature_is_masked = ", feature_is_masked.to(torch.float).mean()) + for n in [ 0, 1 ]: + #device = 'cuda' + B, T, F = 300, 600, 80 + device = 'cpu' + if n == 0: + exp_augment = ExpAugment(p=1.0, frame_mask_size=10) + else: + from lhotse.dataset import SpecAugment + time_mask_ratio = 3.5 + num_frame_masks = int(10 * time_mask_ratio) + max_frames_mask_fraction = 0.15 * time_mask_ratio + print( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + supervision_segments = torch.stack(( + torch.arange(B, device=device), # sequence_idx + torch.zeros(B, device=device, dtype=torch.long), # start_frame + T * torch.ones(B, device=device, dtype=torch.long) # num_frames + ), dim=-1) + exp_augment = lambda x, lengths: spec_augment(x, supervision_segments) + + features = torch.randn(B, T, F, device=device) + lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) + #print("features=", features) + features = exp_augment(features, lengths) + + frame_is_masked = features[:, :, 0] == features[:, :, -1] + print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) + feature_is_masked = features[:, 0] == features[:, -1] + print("mean feature_is_masked = ", feature_is_masked.to(torch.float).mean()) + + + +# from lhotse.dataset import SpecAugment if __name__ == '__main__': _test_exp_augment() From be18f728660ad5656dd29a4035efe511ea33c1ad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 19:28:33 +0800 Subject: [PATCH 0380/1191] Use un-augmented input as reference for reconstruction loss. --- egs/librispeech/ASR/zipformer/model.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 9d54986f03..aac5836a19 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -438,9 +438,12 @@ def forward( supervision_segments=supervision_segments, ) # Independently apply frequency masking and time masking to the two copies - x = spec_augment(x.repeat(2, 1, 1), x_lens.repeat(2).to(x.device)) + + x_no_specaug = x.repeat(2, 1, 1) + x = spec_augment(x_no_spcaug, x_lens.repeat(2).to(x.device)) else: - x = x.repeat(2, 1, 1) + x_no_specaug = x.repeat(2, 1, 1) + x = x_no_specaug x_lens = x_lens.repeat(2) y = k2.ragged.cat([y, y], axis=0) @@ -504,9 +507,9 @@ def forward( else: attention_decoder_loss = torch.empty(0) - reconstruction_loss = self.forward_reconstruction_loss(x, encoder_out, - encoder_out_lens, - use_cr_ctc) + reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, + encoder_out_lens) + if use_cr_ctc: reconstruction_loss = reconstruction_loss * 0.5 @@ -516,11 +519,9 @@ def forward( def forward_reconstruction_loss(self, log_mels: Tensor, encoder_out: Tensor, - encoder_out_lens: Tensor, - use_cr_ctc: bool): + encoder_out_lens: Tensor): """ - Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If - use_cr_ctc then we swap the first and second halves of the batch. + Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. Args: log_mels: log-mel features of shape (batch_size, T, num_mels) @@ -528,8 +529,6 @@ def forward_reconstruction_loss(self, """ batch_size = log_mels.shape[0] num_mels = log_mels.shape[2] - if use_cr_ctc: - log_mels = torch.roll(log_mels, batch_size // 2, dims=0) pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) T_embed = pred_mels.shape[1] From 6591a046a0071a069d76d7949107dd56487cdd34 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 1 Jul 2025 23:17:39 +0800 Subject: [PATCH 0381/1191] Bug fix --- egs/librispeech/ASR/zipformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 8292d7df60..0dcbe18900 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -440,7 +440,7 @@ def forward( # Independently apply frequency masking and time masking to the two copies x_no_specaug = x.repeat(2, 1, 1) - x = spec_augment(x_no_spcaug) + x = spec_augment(x_no_specaug) else: x_no_specaug = x.repeat(2, 1, 1) x = x_no_specaug From 58c63be8854f2f443b17574da0819905f9180861 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Jul 2025 16:13:12 +0800 Subject: [PATCH 0382/1191] Use sum of two exponentials, not one exponential, for all samples. --- icefall/exp_augment.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 785c29c71e..321408feb1 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -99,7 +99,9 @@ def _mask_on_axis(self, def sample_from_exponential(*shape): eps=1.0e-20 - return -(torch.rand(*shape, device=device) + eps).log() + # Modify to sample from mean of two exponential distributions. + a = -(torch.rand(2, *shape, device=device) + eps).log() + return a.mean(dim=0) mask_lengths = sample_from_exponential(B, num_masks) * masked_fraction From 54242a9f3e631edbe552daa9d0d455e115cc49a9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 3 Jul 2025 15:53:06 +0800 Subject: [PATCH 0383/1191] Make the implementation of ExpAugment almost exactly like SpecAug. --- icefall/exp_augment.py | 118 +++++++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 47 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 321408feb1..62ada57076 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -12,22 +12,23 @@ class ExpAugment(torch.nn.Module): """ def __init__( self, - feature_mask_fraction: float = 0.26, # mean fraction masked, not max. + max_feature_mask_fraction: float = 0.675, # max fraction that can possibly be masked num_feature_masks: int = 2, - frame_mask_fraction: float = 0.21, # the mean, not max. - frame_mask_size: float = 50.0, # interpret as mean size of masked region, in frames. - p=0.9, # probability of doing augmentation, and if we do augmentation, of doing each type of augmentation + max_frame_mask_fraction: float = 0.525, + max_frame_mask_size: float = 100, # max size in frames of temporal masks. + p=0.9, # probability of doing augmentation ): super().__init__() assert 0 <= p <= 1 - assert 0 <= feature_mask_fraction <= 1 - assert 0 <= frame_mask_fraction <= 1 - assert 0 < frame_mask_size + assert 0 <= max_feature_mask_fraction <= 1 + assert 0 <= max_frame_mask_fraction <= 1 + assert 0 <= max_frame_mask_size + assert 0 <= num_feature_masks - self.feature_mask_fraction = feature_mask_fraction + self.max_feature_mask_fraction = max_feature_mask_fraction self.num_feature_masks = num_feature_masks - self.frame_mask_fraction = frame_mask_fraction - self.frame_mask_size = frame_mask_size + self.max_frame_mask_fraction = max_frame_mask_fraction + self.max_frame_mask_size = max_frame_mask_size self.p = p def forward( @@ -59,15 +60,20 @@ def forward( features_unmasked = features - features = self._mask_on_axis(features, mean, axis=2, - masked_fraction=self.feature_mask_fraction, - num_masks=self.num_feature_masks) + if self.num_feature_masks > 0: + num_masks = self.num_feature_masks + max_mask_size = F * self.max_feature_mask_fraction / num_masks + features = self._mask_on_axis(features, mean, axis=2, + max_mask_size=max_mask_size, + num_masks=num_masks) - num_masks = max(1, round((T * self.frame_mask_fraction) / self.frame_mask_size)) - features = self._mask_on_axis(features, mean, axis=1, - masked_fraction=self.frame_mask_fraction, - num_masks=num_masks) + if self.max_frame_mask_fraction > 0: + num_masks = max(1, round((T * self.max_frame_mask_fraction) / self.max_frame_mask_size)) + max_mask_size = T * self.max_frame_mask_fraction / num_masks + features = self._mask_on_axis(features, mean, axis=1, + max_mask_size=max_mask_size, + num_masks=num_masks) features = torch.where(torch.rand(B, 1, 1, **kwargs).expand_as(features) < self.p, features, features_unmasked) @@ -78,7 +84,7 @@ def _mask_on_axis(self, features: torch.Tensor, mean: torch.Tensor, axis: int, - masked_fraction: float, + max_mask_size: float, num_masks: int) -> torch.Tensor: """ Mask ``features`` on a particular axis by replacing masked segments of that sequence with @@ -95,34 +101,52 @@ def _mask_on_axis(self, device = features.device shape = list(features.shape) B = shape[0] + M = num_masks N = shape[axis] # T or F - def sample_from_exponential(*shape): - eps=1.0e-20 - # Modify to sample from mean of two exponential distributions. - a = -(torch.rand(2, *shape, device=device) + eps).log() - return a.mean(dim=0) + mask_lengths = torch.rand(B, num_masks, device=device) * max_mask_size + mask_starts = torch.rand(B, num_masks, device=device) * (N - mask_lengths) + mask_ends = mask_starts + mask_lengths - mask_lengths = sample_from_exponential(B, num_masks) * masked_fraction - unmasked_lengths = sample_from_exponential(B, num_masks + 1) * ((1. - masked_fraction) * num_masks / (num_masks + 1)) + mask_boundaries = torch.cat((mask_starts, mask_ends), dim=1) - lengths = torch.empty(B, 2 * num_masks + 1, device=device) - lengths[:, 1::2] = mask_lengths - lengths[:, 0::2] = unmasked_lengths - for _ in range(2): - lengths = lengths * (N / lengths.sum(1, keepdim=True)) - lengths = lengths.round().clamp_(min=1).to(torch.long) + # round down to next integer. + mask_boundaries = mask_boundaries.to(torch.long).clamp(min=0, max=N-1) - positions = lengths.cumsum(dim=1) - positions = positions[:-1].clamp_(max=N-1) # don't need the last position, which should be close to N. + # _masks: (B, M, N) + _masks = torch.logical_and(torch.arange(N) >= mask_starts[..., None], + torch.arange(N) <= mask_ends[..., None]).to(torch.float) + _masks = torch.sum(_masks, dim=1).clamp(max=1) - ones = torch.ones(*positions.shape, device=device, dtype=torch.long) - markers = torch.zeros(B, N, device=device, dtype=torch.long) - markers.scatter_(index=positions, dim=1, src=ones) - cumsum = torch.cumsum(markers, dim=1) - is_masked = ((cumsum % 2) == 1) # (B, N), is True at spots to mask. + is_mask_start = torch.cat((torch.ones(B, M, dtype=torch.bool, device=device), + torch.zeros(B, M, dtype=torch.bool, device=device)), + dim=1) + + mask_boundaries, indexes = mask_boundaries.sort(dim=1) + is_mask_start = torch.gather(is_mask_start, dim=1, index=indexes) + not_mask_start = torch.logical_not(is_mask_start) + + # is_not_repeat is 1 if this element of mask_boundaries is not a repeat of the same boundary + # type as the previous boundary, i.e. mask start or mask end. + + keep_boundary = torch.ones(B, 2 * M, device=device, dtype=torch.bool) + # the following says: set to False all elements of keep_boundary where both this and the previous + # element is a mask start. I.e. remove redundant mask-starts corresponding to overlapping masks. + keep_boundary[:, 1:][torch.logical_and(is_mask_start[:, :-1], is_mask_start[:, 1:])] = False + # the following says: set to False all elements of keep_boundary where both this and the next + # element are mask ends. I.e. remove redundant mask-ends corresponding to overlapping masks. + keep_boundary[:, :-1][torch.logical_and(not_mask_start[:, :-1], not_mask_start[:, 1:])] = False + + keep_boundary = keep_boundary.to(dtype=torch.long) + cumsum = torch.zeros(B, N, device=device, dtype=torch.long) + cumsum.scatter_add_(index=mask_boundaries, dim=1, src=keep_boundary) + + + cumsum = torch.cumsum(cumsum, dim=1) + + is_masked = (cumsum % 2) == 1 # (B, N), is True at spots to mask. if axis == 1: is_masked = is_masked.unsqueeze(-1) else: @@ -132,17 +156,16 @@ def sample_from_exponential(*shape): def state_dict(self, **kwargs) -> Dict[str, Any]: - return dict( - feature_mask_fraction=self.feature_mask_fraction, - num_feature_masks=self.num_feature_masks, - frame_mask_fraction=self.frame_mask_fraction, - frame_mask_size=self.frame_mask_size, - p=self.p) + + dict = { } + for name in ["max_feature_mask_fraction", "num_feature_masks", + "max_frame_mask_fraction", "max_frame_mask_size", "p"]: + dict[name] = getattr(self, name) def load_state_dict(self, state_dict: Dict[str, Any]): - for name in ["feature_mask_fraction", "num_feature_masks", - "frame_mask_fraction", "frame_mask_size", "p"]: + for name in ["max_feature_mask_fraction", "num_feature_masks", + "max_frame_mask_fraction", "max_frame_mask_size", "p"]: if name in state_dict: setattr(self, name, state_dict["name"]) @@ -155,7 +178,7 @@ def _test_exp_augment(): device = 'cpu' if n == 0: - exp_augment = ExpAugment(p=1.0, frame_mask_size=10) + exp_augment = ExpAugment(p=1.0) #, max_frame_mask_size=2.0, max_frame_mask_fraction=0.02) else: from lhotse.dataset import SpecAugment time_mask_ratio = 3.5 @@ -172,6 +195,7 @@ def _test_exp_augment(): num_feature_masks=2, frames_mask_size=100, max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + p=1.0, ) supervision_segments = torch.stack(( torch.arange(B, device=device), # sequence_idx From 0c52177579d0c9df569e30c36886d8b7bc7ab103 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 3 Jul 2025 16:03:02 +0800 Subject: [PATCH 0384/1191] Bug fix --- icefall/exp_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 62ada57076..2ba2f123b2 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -116,8 +116,8 @@ def _mask_on_axis(self, # _masks: (B, M, N) - _masks = torch.logical_and(torch.arange(N) >= mask_starts[..., None], - torch.arange(N) <= mask_ends[..., None]).to(torch.float) + _masks = torch.logical_and(torch.arange(N, device=device) >= mask_starts[..., None], + torch.arange(N, device=device) <= mask_ends[..., None]).to(torch.float) _masks = torch.sum(_masks, dim=1).clamp(max=1) is_mask_start = torch.cat((torch.ones(B, M, dtype=torch.bool, device=device), From 831dae2ebe2f569de6c5e48adeeba8d817e2c869 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Jul 2025 11:23:25 +0800 Subject: [PATCH 0385/1191] Make the masks non-overlapping. --- icefall/exp_augment.py | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 2ba2f123b2..d464ec1700 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -104,10 +104,7 @@ def _mask_on_axis(self, M = num_masks N = shape[axis] # T or F - mask_lengths = torch.rand(B, num_masks, device=device) * max_mask_size - - mask_starts = torch.rand(B, num_masks, device=device) * (N - mask_lengths) - mask_ends = mask_starts + mask_lengths + mask_starts, mask_ends = self._sample_mask_starts_and_ends(B, N, num_masks, max_mask_size, device) mask_boundaries = torch.cat((mask_starts, mask_ends), dim=1) @@ -155,6 +152,40 @@ def _mask_on_axis(self, return torch.where(is_masked.expand_as(features), mean[None, None, None].expand_as(features), features) + def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_size, device) -> Tuple[Tuple,Tuple]: + # compute the start and end positions of masked regions. this will select mask positions + # that do not overlap. Return: (mask_starts, mask_ends) + + mask_lengths = torch.rand(batch_size, num_masks, device=device) * max_mask_size + mask_tot_len = mask_lengths.sum(dim=1, keepdim=True) # (batch_size, 1) + padding_tot_len = seq_len - mask_tot_len # (batch_size, 1) + eps = 1.0e-20 + + # get padding lengths by randomly placing dividers on the line of length "padding_tot_len" + # these "padding_positions" are not absolute position on the line from 0 to seq_len, + # but positions on the line from 0 to "padding_tot_len" which divides up the length + # we need to pad. + num_pads = num_masks + 1 + padding_positions = torch.rand(batch_size, num_pads - 1, device=device) * padding_tot_len + padding_positions = padding_positions.sort(dim=1)[0] + zero = torch.zeros(batch_size, 1, device=device) + padding_positions = torch.cat((zero, padding_positions, padding_tot_len), dim=1) + padding_lengths = padding_positions[:, 1:] - padding_positions[:, :-1] + + lengths = torch.empty(batch_size, num_masks * 2 + 1, device=device) + lengths[:, 1::2] = mask_lengths + lengths[:, 0::2] = padding_lengths + + positions = torch.cumsum(lengths, dim=1) + # last element of 'positions' should be seq_len + assert torch.all((positions[:, -1] - seq_len).abs() < 0.0001 * seq_len) + + # positions does not have a leading zero, cumsum is inclusive; but do not treat final `seq_len` as a mask start position. + mask_starts = positions[:, 0:-1:2] + mask_ends = positions[:, 1::2] + assert mask_starts.shape == (batch_size, num_masks) and mask_ends.shape == (batch_size, num_masks) + return mask_starts, mask_ends + def state_dict(self, **kwargs) -> Dict[str, Any]: dict = { } From de997ce0541fe3978aee8541b89662499ed36698 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Jul 2025 12:12:01 +0800 Subject: [PATCH 0386/1191] Roll half of the masked regions between the two copies of the same data, to make masks anti-correlated between them. --- icefall/exp_augment.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index d464ec1700..b57bf9120e 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -106,6 +106,13 @@ def _mask_on_axis(self, mask_starts, mask_ends = self._sample_mask_starts_and_ends(B, N, num_masks, max_mask_size, device) + # roll half or the mask_starts and mask_ends between the first and second + # halves of the batch. this is intended to help CR-CTC, by making the + # masked regions of the two augmented versions of the same data anti-correlated. + mask_starts[:, ::2] = mask_starts[:, ::2].roll(batch_size // 2, dim=0) + mask_ends[:, ::2] = mask_ends[:, ::2].roll(batch_size // 2, dim=0) + + mask_boundaries = torch.cat((mask_starts, mask_ends), dim=1) # round down to next integer. @@ -184,6 +191,7 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ mask_starts = positions[:, 0:-1:2] mask_ends = positions[:, 1::2] assert mask_starts.shape == (batch_size, num_masks) and mask_ends.shape == (batch_size, num_masks) + return mask_starts, mask_ends def state_dict(self, **kwargs) -> Dict[str, Any]: From 0eaddb98a26526d4a534bca65fe37a1664dff1a4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Jul 2025 12:44:44 +0800 Subject: [PATCH 0387/1191] Copy zapformer directory from 857 so both jobs can run. --- egs/librispeech/ASR/zapformer/.gitignore | 1 + .../ASR/zapformer/asr_datamodule.py | 454 +++++ .../ASR/zapformer/attention_decoder.py | 1 + egs/librispeech/ASR/zapformer/beam_search.py | 1 + egs/librispeech/ASR/zapformer/ctc_decode.py | 1 + egs/librispeech/ASR/zapformer/decode.py | 1 + .../ASR/zapformer/decode_gigaspeech.py | 1 + .../ASR/zapformer/decode_stream.py | 1 + egs/librispeech/ASR/zapformer/decoder.py | 1 + .../ASR/zapformer/encoder_interface.py | 1 + .../ASR/zapformer/export-onnx-ctc.py | 1 + .../zapformer/export-onnx-streaming-ctc.py | 1 + .../ASR/zapformer/export-onnx-streaming.py | 1 + egs/librispeech/ASR/zapformer/export-onnx.py | 1 + egs/librispeech/ASR/zapformer/export.py | 1 + egs/librispeech/ASR/zapformer/finetune.py | 1 + .../ASR/zapformer/generate_averaged_model.py | 1 + .../ASR/zapformer/jit_pretrained.py | 1 + .../ASR/zapformer/jit_pretrained_ctc.py | 1 + .../ASR/zapformer/jit_pretrained_streaming.py | 1 + egs/librispeech/ASR/zapformer/joiner.py | 1 + .../ASR/zapformer/label_smoothing.py | 1 + egs/librispeech/ASR/zapformer/model.py | 571 ++++++ egs/librispeech/ASR/zapformer/my_profile.py | 1 + egs/librispeech/ASR/zapformer/onnx_check.py | 1 + egs/librispeech/ASR/zapformer/onnx_decode.py | 1 + .../onnx_pretrained-streaming-ctc.py | 1 + .../zapformer/onnx_pretrained-streaming.py | 1 + .../ASR/zapformer/onnx_pretrained.py | 1 + .../ASR/zapformer/onnx_pretrained_ctc.py | 1 + .../ASR/zapformer/onnx_pretrained_ctc_H.py | 1 + .../ASR/zapformer/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zapformer/onnx_pretrained_ctc_HLG.py | 1 + .../onnx_pretrained_ctc_HLG_streaming.py | 1 + egs/librispeech/ASR/zapformer/optim.py | 1 + egs/librispeech/ASR/zapformer/pretrained.py | 1 + .../ASR/zapformer/pretrained_ctc.py | 1 + egs/librispeech/ASR/zapformer/scaling.py | 1 + .../ASR/zapformer/scaling_converter.py | 1 + .../ASR/zapformer/speech_recognition.py | 229 +++ .../ASR/zapformer/streaming_beam_search.py | 1 + .../ASR/zapformer/streaming_decode.py | 1 + egs/librispeech/ASR/zapformer/subsampling.py | 1 + egs/librispeech/ASR/zapformer/test_scaling.py | 1 + .../ASR/zapformer/test_subsampling.py | 1 + egs/librispeech/ASR/zapformer/train.py | 1690 +++++++++++++++++ egs/librispeech/ASR/zapformer/zipformer.py | 1 + 47 files changed, 2987 insertions(+) create mode 100644 egs/librispeech/ASR/zapformer/.gitignore create mode 100755 egs/librispeech/ASR/zapformer/asr_datamodule.py create mode 120000 egs/librispeech/ASR/zapformer/attention_decoder.py create mode 120000 egs/librispeech/ASR/zapformer/beam_search.py create mode 120000 egs/librispeech/ASR/zapformer/ctc_decode.py create mode 120000 egs/librispeech/ASR/zapformer/decode.py create mode 120000 egs/librispeech/ASR/zapformer/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/zapformer/decode_stream.py create mode 120000 egs/librispeech/ASR/zapformer/decoder.py create mode 120000 egs/librispeech/ASR/zapformer/encoder_interface.py create mode 120000 egs/librispeech/ASR/zapformer/export-onnx-ctc.py create mode 120000 egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer/export-onnx-streaming.py create mode 120000 egs/librispeech/ASR/zapformer/export-onnx.py create mode 120000 egs/librispeech/ASR/zapformer/export.py create mode 120000 egs/librispeech/ASR/zapformer/finetune.py create mode 120000 egs/librispeech/ASR/zapformer/generate_averaged_model.py create mode 120000 egs/librispeech/ASR/zapformer/jit_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py create mode 120000 egs/librispeech/ASR/zapformer/joiner.py create mode 120000 egs/librispeech/ASR/zapformer/label_smoothing.py create mode 100755 egs/librispeech/ASR/zapformer/model.py create mode 120000 egs/librispeech/ASR/zapformer/my_profile.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_check.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_decode.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py create mode 120000 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py create mode 120000 egs/librispeech/ASR/zapformer/optim.py create mode 120000 egs/librispeech/ASR/zapformer/pretrained.py create mode 120000 egs/librispeech/ASR/zapformer/pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer/scaling.py create mode 120000 egs/librispeech/ASR/zapformer/scaling_converter.py create mode 100755 egs/librispeech/ASR/zapformer/speech_recognition.py create mode 120000 egs/librispeech/ASR/zapformer/streaming_beam_search.py create mode 120000 egs/librispeech/ASR/zapformer/streaming_decode.py create mode 120000 egs/librispeech/ASR/zapformer/subsampling.py create mode 120000 egs/librispeech/ASR/zapformer/test_scaling.py create mode 120000 egs/librispeech/ASR/zapformer/test_subsampling.py create mode 100755 egs/librispeech/ASR/zapformer/train.py create mode 120000 egs/librispeech/ASR/zapformer/zipformer.py diff --git a/egs/librispeech/ASR/zapformer/.gitignore b/egs/librispeech/ASR/zapformer/.gitignore new file mode 100644 index 0000000000..e47ac15828 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py new file mode 100755 index 0000000000..4db6e101fb --- /dev/null +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -0,0 +1,454 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +# This K2SpeechRecognitionDataset is a modified version of one from +# lhotse.dataset, modified to, in training mode, to return a batch that has 3 +# different copies of the same data with the last two having different Musan +# augmentations and the first having none; and also include the key "num_copies" +# in the batch which would be 1 for the validation data (no Musan) and 3 for the +# training data with musan. +from speech_recognition import K2SpeechRecognitionDataset +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=[], + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zapformer/attention_decoder.py b/egs/librispeech/ASR/zapformer/attention_decoder.py new file mode 120000 index 0000000000..830180a0cd --- /dev/null +++ b/egs/librispeech/ASR/zapformer/attention_decoder.py @@ -0,0 +1 @@ +../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/beam_search.py b/egs/librispeech/ASR/zapformer/beam_search.py new file mode 120000 index 0000000000..8554e44ccf --- /dev/null +++ b/egs/librispeech/ASR/zapformer/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/ctc_decode.py b/egs/librispeech/ASR/zapformer/ctc_decode.py new file mode 120000 index 0000000000..a78e5c1df0 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/ctc_decode.py @@ -0,0 +1 @@ +../zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py new file mode 120000 index 0000000000..82581c6d36 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -0,0 +1 @@ +../zipformer/decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer/decode_gigaspeech.py new file mode 120000 index 0000000000..63b0ef617b --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decode_gigaspeech.py @@ -0,0 +1 @@ +../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decode_stream.py b/egs/librispeech/ASR/zapformer/decode_stream.py new file mode 120000 index 0000000000..4e59d04a12 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decode_stream.py @@ -0,0 +1 @@ +../zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decoder.py b/egs/librispeech/ASR/zapformer/decoder.py new file mode 120000 index 0000000000..cab465d2b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decoder.py @@ -0,0 +1 @@ +../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/encoder_interface.py b/egs/librispeech/ASR/zapformer/encoder_interface.py new file mode 120000 index 0000000000..aa5d0217a8 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer/export-onnx-ctc.py new file mode 120000 index 0000000000..dc14e93e75 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py new file mode 120000 index 0000000000..3baa2b673c --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer/export-onnx-streaming.py new file mode 120000 index 0000000000..d18cb9a9a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export-onnx.py b/egs/librispeech/ASR/zapformer/export-onnx.py new file mode 120000 index 0000000000..f343cf7027 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export-onnx.py @@ -0,0 +1 @@ +../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export.py b/egs/librispeech/ASR/zapformer/export.py new file mode 120000 index 0000000000..1a126ab695 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export.py @@ -0,0 +1 @@ +../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/finetune.py b/egs/librispeech/ASR/zapformer/finetune.py new file mode 120000 index 0000000000..0e9e7989b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/finetune.py @@ -0,0 +1 @@ +../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/generate_averaged_model.py b/egs/librispeech/ASR/zapformer/generate_averaged_model.py new file mode 120000 index 0000000000..b65513a058 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/generate_averaged_model.py @@ -0,0 +1 @@ +../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained.py b/egs/librispeech/ASR/zapformer/jit_pretrained.py new file mode 120000 index 0000000000..5d45825206 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/jit_pretrained.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py new file mode 120000 index 0000000000..43aeb684bf --- /dev/null +++ b/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py new file mode 120000 index 0000000000..8e5e6f9812 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/joiner.py b/egs/librispeech/ASR/zapformer/joiner.py new file mode 120000 index 0000000000..444cb5f150 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/joiner.py @@ -0,0 +1 @@ +../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/label_smoothing.py b/egs/librispeech/ASR/zapformer/label_smoothing.py new file mode 120000 index 0000000000..3690afff9d --- /dev/null +++ b/egs/librispeech/ASR/zapformer/label_smoothing.py @@ -0,0 +1 @@ +../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py new file mode 100755 index 0000000000..56f744d5ea --- /dev/null +++ b/egs/librispeech/ASR/zapformer/model.py @@ -0,0 +1,571 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from encoder_interface import EncoderInterface +from lhotse.dataset import SpecAugment +from scaling import ScaledLinear, convert_num_channels + +from icefall.utils import add_sos, make_pad_mask, time_warp + + +class AsrModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + use_attention_decoder: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + use_attention_decoder: + Whether use attention-decoder head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.1, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.1, + ) + + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + nn.LogSoftmax(dim=-1), + ) + + self.use_attention_decoder = use_attention_decoder + if use_attention_decoder: + self.attention_decoder = attention_decoder + else: + assert attention_decoder is None + + self.reconstruction_proj = ScaledLinear( + encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) + self.reconstruction_loss = torch.nn.SmoothL1Loss(reduction='none', beta=1.0) + + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) + + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + + src_key_padding_mask = make_pad_mask(x_lens) # (N, T) + specaug_mask = specaug_mask[:, ::2] + assert abs(specaug_mask.shape[1] - src_key_padding_mask.shape[1]) < 10 + specaug_mask = convert_num_channels(specaug_mask, src_key_padding_mask.shape[1]) # pad or truncate. (N, T) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens, predict_loss + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). This activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + return ctc_loss + + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss with consistency regularization loss. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.amp.autocast('cuda', enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.amp.autocast('cuda', enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + num_copies: int = 1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + spec_augment: + The SpecAugment instance that returns time masks, + used only if use_cr_ctc is True. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. + Used only if use_cr_ctc is True. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if use_cr_ctc is True. + num_copies: + the number of copies of the same data that are in the batch, e.g. 1, 2 + or 3; affects CRCTC, spec-augment, etc. + + Returns: + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + device = x.device + + if num_copies > 1: + assert num_copies == 3 # for now. + # will do SpecAugment. + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + + (batch_size, seq_len, num_channels) = x.shape + B = batch_size // num_copies + x = x.reshape(num_copies, B, seq_len, num_channels) + + # Apply time warping. First append the copies on the channel + # dimension so all copies get the exact same time-warping. + x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) + + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments[:B], + ) + x = x.reshape(B, seq_len, num_copies, num_channels) + x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + + # x_no_specaug is several repeats of the 1st copy of the data, which + # is the one not augmented with Musan. But it does have time + # warping. + x_no_specaug = x[0:1].repeat(num_copies - 1, 1, 1, 1).reshape( + B * (num_copies - 1), seq_len, num_channels) + + + # Independently apply frequency masking and time masking to all but the first + # copy of the data. + x = spec_augment(x[1:].reshape(-1, seq_len, num_channels)) + + x_lens = x_lens[:B*(num_copies-1)] + y = y[:B*(num_copies-1)] + else: + x_no_specaug = x + + + # Compute encoder outputs + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + targets = y.values + #if not use_cr_ctc: + #ctc_loss = self.forward_ctc( + #encoder_out=encoder_out, + #encoder_out_lens=encoder_out_lens, + #targets=targets, + #target_lengths=y_lens, + #) + #cr_loss = torch.empty(0) + + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) + + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, + encoder_out_lens) + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss + + + def forward_reconstruction_loss(self, + log_mels: Tensor, + encoder_out: Tensor, + encoder_out_lens: Tensor): + """ + Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If + use_cr_ctc then we swap the first and second halves of the batch. + + Args: + log_mels: log-mel features of shape (batch_size, T, num_mels) + encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) + """ + batch_size = log_mels.shape[0] + num_mels = log_mels.shape[2] + + pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) + T_embed = pred_mels.shape[1] + pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) + + excess_frames = log_mels.shape[1] - pred_mels.shape[1] + assert 4 < excess_frames < 10 # should be around 7 or 8 I believe. + + T = pred_mels.shape[1] + offset = 3 # i found excess_frames = 5 one time. + log_mels = log_mels[:, offset:offset+T] + + lens = encoder_out_lens * 4 + pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions + assert pad_mask.shape == (batch_size, T) + pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position + # padd_mask: (batch_size, T, 1) + + + # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly + # helps to down-weight the effect of very silent silences. + loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, + reduction='none', beta=1.0) + + # masking. if it's different from the next item on both the frequency dim + # and the time dim, it means we are in neither a time masked nor a frequency masked + # position. + mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), + log_mels != torch.roll(log_mels, 1, dims=1)) + loss = loss * mask.to(loss.dtype) + + loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. + return loss diff --git a/egs/librispeech/ASR/zapformer/my_profile.py b/egs/librispeech/ASR/zapformer/my_profile.py new file mode 120000 index 0000000000..76e48b756b --- /dev/null +++ b/egs/librispeech/ASR/zapformer/my_profile.py @@ -0,0 +1 @@ +../zipformer/my_profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_check.py b/egs/librispeech/ASR/zapformer/onnx_check.py new file mode 120000 index 0000000000..7293c70d46 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_check.py @@ -0,0 +1 @@ +../zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_decode.py b/egs/librispeech/ASR/zapformer/onnx_decode.py new file mode 120000 index 0000000000..9e3faa5e01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_decode.py @@ -0,0 +1 @@ +../zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py new file mode 120000 index 0000000000..f8abb9daa5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py new file mode 120000 index 0000000000..11b846322e --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained.py b/egs/librispeech/ASR/zapformer/onnx_pretrained.py new file mode 120000 index 0000000000..a085def837 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py new file mode 120000 index 0000000000..0c082a204f --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py new file mode 120000 index 0000000000..68102c7374 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py new file mode 120000 index 0000000000..8314b4efdf --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py new file mode 120000 index 0000000000..7a637a1c01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py new file mode 120000 index 0000000000..a5b04b3f8b --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/optim.py b/egs/librispeech/ASR/zapformer/optim.py new file mode 120000 index 0000000000..207eecfcda --- /dev/null +++ b/egs/librispeech/ASR/zapformer/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/pretrained.py b/egs/librispeech/ASR/zapformer/pretrained.py new file mode 120000 index 0000000000..70ad71ffc6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/pretrained.py @@ -0,0 +1 @@ +../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/pretrained_ctc.py b/egs/librispeech/ASR/zapformer/pretrained_ctc.py new file mode 120000 index 0000000000..fb9bdf1fa2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/scaling.py b/egs/librispeech/ASR/zapformer/scaling.py new file mode 120000 index 0000000000..58e4b0a0fe --- /dev/null +++ b/egs/librispeech/ASR/zapformer/scaling.py @@ -0,0 +1 @@ +../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/scaling_converter.py b/egs/librispeech/ASR/zapformer/scaling_converter.py new file mode 120000 index 0000000000..bc7c7b5e37 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/scaling_converter.py @@ -0,0 +1 @@ +../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/speech_recognition.py b/egs/librispeech/ASR/zapformer/speech_recognition.py new file mode 100755 index 0000000000..dd069cf3da --- /dev/null +++ b/egs/librispeech/ASR/zapformer/speech_recognition.py @@ -0,0 +1,229 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class K2SpeechRecognitionDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech recognition task using k2 library. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + if self.cut_transforms: + orig_cuts = cuts + + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts + num_copies = 3 + else: + num_copies = 1 + + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) diff --git a/egs/librispeech/ASR/zapformer/streaming_beam_search.py b/egs/librispeech/ASR/zapformer/streaming_beam_search.py new file mode 120000 index 0000000000..97e6e733f2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/streaming_beam_search.py @@ -0,0 +1 @@ +../zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py new file mode 120000 index 0000000000..e31da07d01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/streaming_decode.py @@ -0,0 +1 @@ +../zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/subsampling.py b/egs/librispeech/ASR/zapformer/subsampling.py new file mode 120000 index 0000000000..d178adc2e5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/subsampling.py @@ -0,0 +1 @@ +../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/test_scaling.py b/egs/librispeech/ASR/zapformer/test_scaling.py new file mode 120000 index 0000000000..b776da79a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/test_scaling.py @@ -0,0 +1 @@ +../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/test_subsampling.py b/egs/librispeech/ASR/zapformer/test_subsampling.py new file mode 120000 index 0000000000..2925ea3c51 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/test_subsampling.py @@ -0,0 +1 @@ +../zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py new file mode 100755 index 0000000000..ce2f9507c9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/train.py @@ -0,0 +1,1690 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset import SpecAugment +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel # TODO: change to model +from optim import Sched3, TransformedAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def get_adjusted_lr_batches(params: AttributeDict) -> float: + # returns an adjusted form of the "lr_batches" parameter used to set the learning + # rate in the Sched3 scheduler. + # We want the final LR to be based on the geometric mean of "how much data we + # have seen" and "how many batches we have seen". + # an easier way to look at it is this: the formula for learning rate depends + # on (cur_batch / lr_batches). if we write this as: + # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches + # then the numerator is a geometric mean of "how many batches we have seen" + # and "how much data we have seen". We can achieve this by setting + # lr_batches = params.lr_batches * (duration_ratio ** -0.5). + duration_ratio = (params.max_duration * params.world_size) / params.ref_duration + lr_batches = params.lr_batches * (duration_ratio ** -0.5) + logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") + return lr_batches + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def lookup(params: AttributeDict, name: str): + """ + Interprets numerical arguments in `params` by taking into account base-dim; + also parses comma-separated lists of integers, turning them into tuples. + If a particular attribute ending in "dim" is not present we look up + the same name but ending in "factor", and multiply the elements by base_dim. + """ + try: + attr = getattr(params, name) + try: + attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints + if len(attr) == 1: + attr = attr[0] + except: + pass # leave attr as it is, e.g. a string. + return attr + except AttributeError as e: + if name[-3:] != "dim": + raise e + try: + attr = getattr(params, name[:-3] + "multiple") + if isinstance(attr, str): + attr = tuple(map(int, attr.split(","))) # tuple of ints + base_dim = params.base_dim + attr = tuple([i * base_dim for i in attr]) + if len(attr) == 1: + attr = attr[0] + else: # assume int. + assert isinstance(attr, (int, float)), (name, attr) + attr = attr * params.base_dim + return attr + except AttributeError as e: + raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") + + + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="3,5,7,5,4,7,5", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--base-dim", + type=int, + default=64, + help="Dimension that, via multiples, defines the dimensions of the model." + ) + + parser.add_argument( + "--embed-multiple", + type=int, + default=6, + help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", + ) + + parser.add_argument( + "--feedforward-multiple", + type=str, + default="3,3,3,3,3,3,3", + help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,8,4,4", + help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-multiple", + type=str, + default="4,6,9,12,12,9,6", + help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-multiple", + type=int, + default=8, + help="Factor by which embedding dimension in the decoder model is larger than base-dim.", + ) + + parser.add_argument( + "--joiner-multiple", + type=int, + default=8, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--attention-decoder-multiple", + type=int, + default=8, + help="""Factor by which attention decoder dim is larger than base-dim""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-multiple", + type=int, + default=8, + help="""Determines attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-multiple", + type=int, + default=4, + help="""Factor by which feedforward hidden dim in attention decoder is larger than attention-decoder-dim""" + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--debug-interval", + type=int, + default=10, + help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for + finding parts of the network that are diverging or not well trained. + """ + ) + + parser.add_argument( + "--dump-debug-interval", + type=int, + default=0, + help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they + are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. + Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file + opened, seeked-in and closed for each scalar that is written). + """ + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=17500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--reconstruction-loss-scale", + type=float, + default=0.005, + help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", + ) + + parser.add_argument( + "--predict-loss-scale", + type=float, + default=0.01, + help="Prediction of random k-means after widest zipformer layer" + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on pruned loss (for transducer) and the reconstruction and prediction + losses. Expressed in terms of the "adjusted batch count", i.e. the + normalized batch count after adjusting for changes in batch size. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=lookup(params, "embed_dim"), + dropout=0.0, + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + input_dim=lookup(params, "embed_dim"), + output_downsampling_factor=2, + downsampling_factor=lookup(params, "downsampling_factor"), + num_encoder_layers=lookup(params, "num_encoder_layers"), + encoder_dim=lookup(params, "encoder_dim"), + query_head_dim=lookup(params, "query_head_dim"), + pos_head_dim=lookup(params, "pos_head_dim"), + value_head_dim=lookup(params, "value_head_dim"), + pos_dim=params.pos_dim, + num_heads=lookup(params, "num_heads"), + feedforward_multiple=lookup(params, "feedforward_multiple"), + cnn_module_kernel=lookup(params, "cnn_module_kernel"), + dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero + causal=params.causal, + chunk_size=lookup(params, "chunk_size"), + left_context_frames=lookup(params, "left_context_frames"), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "decoder_dim"), + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 + joiner = Joiner( + encoder_dim=lookup(params, "embed_dim") * output_downsampling_factor, + decoder_dim=lookup(params, "decoder_dim"), + joiner_dim=lookup(params, "joiner_dim"), + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "attention_decoder_dim"), + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=lookup(params, "attention_decoder_attention_dim"), + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_multiple * lookup(params, "attention_decoder_attention_dim"), + memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + output_downsampling_factor = 2 + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=output_downsampling_factor * lookup(params, "embed_dim"), + decoder_dim=lookup(params, "decoder_dim"), + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + spec_augment: Optional[SpecAugment] = None, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + spec_augment: + The SpecAugment instance, used for training + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + num_copies = batch["num_copies"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + if num_copies > 1: + assert model.training + # will need the following for time-warping in SpecAugment. + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + spec_augment = None # disable spec-aug + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=80, # for specaug + num_copies=num_copies, + ) + + loss = 0.0 + + adjusted_batch_count = params.batch_idx_train + warm_step = params.warm_step + def warmup_schedule(scale, initial_factor): + # geometric warmup schedules. + warmup_factor = (1. if adjusted_batch_count >= warm_step else + initial_factor + (adjusted_batch_count / warm_step) * (1 - initial_factor)) + return scale * warmup_factor + + if params.use_transducer: + simple_loss_scale = params.simple_loss_scale + pruned_loss_scale = warmup_schedule(1.0, 0.05) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + loss += params.cr_loss_scale * cr_loss + + reconstruction_loss_scale = params.reconstruction_loss_scale + + loss += reconstruction_loss_scale * reconstruction_loss + + loss += params.predict_loss_scale * predict_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + nframes = (feature_lens // params.subsampling_factor).sum().item() + if num_copies > 1: + nframes = nframes * (num_copies - 1) / num_copies # omit 1st copy + info["frames"] = nframes + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + info["cr_loss"] = cr_loss.detach().cpu().item() + info["predict_loss"] = predict_loss.detach().cpu().item() + info["recon_loss"] = reconstruction_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used for CR-CTC. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + if params.debug_interval > 0: + optimizer.write_debug_info(summary_writer=tb_writer) + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or + batch_idx % 100 == 0 and cur_grad_scale < 8.0 or + batch_idx % 400 == 0 and cur_grad_scale < 32.0): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + optimizer.write_debug_info(summary_writer=tb_writer) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + + assert params.use_ctc # for now, require CTC, we may remove this requirement later. + + spec_augment = get_spec_augment(params) + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = TransformedAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + debug_interval=params.debug_interval, + ) + + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and False: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + spec_augment=spec_augment, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + spec_augment=spec_augment, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + d = diagnostic.print_diagnostics() + filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" + torch.save(d, filename) + logging.info(f"Saved detailed diagnostics to {filename}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer/zipformer.py b/egs/librispeech/ASR/zapformer/zipformer.py new file mode 120000 index 0000000000..a064749a48 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/zipformer.py @@ -0,0 +1 @@ +../zipformer/zipformer.py \ No newline at end of file From f0b7ccdf645918515361dae8497ecb072612b3f5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Jul 2025 13:21:12 +0800 Subject: [PATCH 0388/1191] Bug fix in exp_augment.py --- icefall/exp_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index b57bf9120e..d63e8cee61 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -109,8 +109,8 @@ def _mask_on_axis(self, # roll half or the mask_starts and mask_ends between the first and second # halves of the batch. this is intended to help CR-CTC, by making the # masked regions of the two augmented versions of the same data anti-correlated. - mask_starts[:, ::2] = mask_starts[:, ::2].roll(batch_size // 2, dim=0) - mask_ends[:, ::2] = mask_ends[:, ::2].roll(batch_size // 2, dim=0) + mask_starts[:, ::2] = mask_starts[:, ::2].roll(B // 2, dim=0) + mask_ends[:, ::2] = mask_ends[:, ::2].roll(B // 2, dim=0) mask_boundaries = torch.cat((mask_starts, mask_ends), dim=1) From cf9be950e25d2cb528fc1d96e875c1d444b48750 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Jul 2025 14:21:38 +0800 Subject: [PATCH 0389/1191] Bug fix --- icefall/exp_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index d63e8cee61..9c7e75b8f2 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -109,8 +109,8 @@ def _mask_on_axis(self, # roll half or the mask_starts and mask_ends between the first and second # halves of the batch. this is intended to help CR-CTC, by making the # masked regions of the two augmented versions of the same data anti-correlated. - mask_starts[:, ::2] = mask_starts[:, ::2].roll(B // 2, dim=0) - mask_ends[:, ::2] = mask_ends[:, ::2].roll(B // 2, dim=0) + mask_starts[:, ::2] = mask_starts[:, ::2].roll(B // 2, 0) + mask_ends[:, ::2] = mask_ends[:, ::2].roll(B // 2, 0) mask_boundaries = torch.cat((mask_starts, mask_ends), dim=1) From 2111b2e535e7c8ea23cfa9b1d16e60d177101abf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Jul 2025 12:13:51 +0800 Subject: [PATCH 0390/1191] Make the masks completely non-overlapping between sequences. --- icefall/exp_augment.py | 102 ++++++++++++++++++++++++++--------------- 1 file changed, 64 insertions(+), 38 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 9c7e75b8f2..a2adaaf8c6 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -62,17 +62,15 @@ def forward( if self.num_feature_masks > 0: num_masks = self.num_feature_masks - max_mask_size = F * self.max_feature_mask_fraction / num_masks features = self._mask_on_axis(features, mean, axis=2, - max_mask_size=max_mask_size, + max_mask_fraction=self.max_feature_mask_fraction, num_masks=num_masks) if self.max_frame_mask_fraction > 0: num_masks = max(1, round((T * self.max_frame_mask_fraction) / self.max_frame_mask_size)) - max_mask_size = T * self.max_frame_mask_fraction / num_masks features = self._mask_on_axis(features, mean, axis=1, - max_mask_size=max_mask_size, + max_mask_fraction=self.max_frame_mask_fraction, num_masks=num_masks) features = torch.where(torch.rand(B, 1, 1, **kwargs).expand_as(features) < self.p, @@ -84,7 +82,7 @@ def _mask_on_axis(self, features: torch.Tensor, mean: torch.Tensor, axis: int, - max_mask_size: float, + max_mask_fraction: float, num_masks: int) -> torch.Tensor: """ Mask ``features`` on a particular axis by replacing masked segments of that sequence with @@ -93,7 +91,8 @@ def _mask_on_axis(self, :param features: a batch of feature matrices with shape ``(B, T, F)``. :param mean: the overall feature-matrix mean, a scalar. :param axis: the axis to mask on, i.e. 1 for time, 2 for frequency/feature. - :param masked_fraction: the fraction of the data to mask, in expectation. + :param max_mask_fraction: the maximum fraction of the data to mask (expected value will be + close to half of this.) :param num_masks: the number of masked regions. """ assert axis in [1,2] @@ -104,14 +103,7 @@ def _mask_on_axis(self, M = num_masks N = shape[axis] # T or F - mask_starts, mask_ends = self._sample_mask_starts_and_ends(B, N, num_masks, max_mask_size, device) - - # roll half or the mask_starts and mask_ends between the first and second - # halves of the batch. this is intended to help CR-CTC, by making the - # masked regions of the two augmented versions of the same data anti-correlated. - mask_starts[:, ::2] = mask_starts[:, ::2].roll(B // 2, 0) - mask_ends[:, ::2] = mask_ends[:, ::2].roll(B // 2, 0) - + mask_starts, mask_ends = self._sample_mask_starts_and_ends(B, N, num_masks, max_mask_fraction, device) mask_boundaries = torch.cat((mask_starts, mask_ends), dim=1) @@ -159,29 +151,50 @@ def _mask_on_axis(self, return torch.where(is_masked.expand_as(features), mean[None, None, None].expand_as(features), features) - def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_size, device) -> Tuple[Tuple,Tuple]: + def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_fraction, device) -> Tuple[Tuple,Tuple]: # compute the start and end positions of masked regions. this will select mask positions - # that do not overlap. Return: (mask_starts, mask_ends) - - mask_lengths = torch.rand(batch_size, num_masks, device=device) * max_mask_size - mask_tot_len = mask_lengths.sum(dim=1, keepdim=True) # (batch_size, 1) - padding_tot_len = seq_len - mask_tot_len # (batch_size, 1) + # that do not overlap. Return: (mask_starts, mask_ends). + + # we sample the masks for pairs of sequences. + B = (batch_size + 1) // 2 + # M is the number of masks we sample for each pair of sequences. + M = 2 * num_masks + + # "rlength" means relative length of each mask, i.e. relative to seq_len. the + # lengths in mask_lengths are normalized lengths. + mask_rlengths = torch.rand(B, M, device=device) * (max_mask_fraction / num_masks) + mask_tot_rlen = mask_rlengths.sum(dim=1, keepdim=True) # (batch_size, 1) + + # padding_tot_rlen is the total relative length of the padding segmnts. We clamp to min=0.25 + # so there is some randomness in the positions even if the selected masks are unusually large. + # (note: we expect the max_fraction values to be between about .5 to .7, so the expected-masked-fraction + # values would be about .25 to 0.35 (since we sample between 0 and maximum); and if we double + # it because we do the selection for pairs of masked regions, that gives us about .5 to .7. + # so definitely this clamping will happen for less than half of the pairs of sequences. + + padding_tot_rlen = (1. - mask_tot_rlen).clamp(min=0.2) # (batch_size, 1) eps = 1.0e-20 - # get padding lengths by randomly placing dividers on the line of length "padding_tot_len" - # these "padding_positions" are not absolute position on the line from 0 to seq_len, - # but positions on the line from 0 to "padding_tot_len" which divides up the length - # we need to pad. - num_pads = num_masks + 1 - padding_positions = torch.rand(batch_size, num_pads - 1, device=device) * padding_tot_len - padding_positions = padding_positions.sort(dim=1)[0] - zero = torch.zeros(batch_size, 1, device=device) - padding_positions = torch.cat((zero, padding_positions, padding_tot_len), dim=1) - padding_lengths = padding_positions[:, 1:] - padding_positions[:, :-1] - - lengths = torch.empty(batch_size, num_masks * 2 + 1, device=device) - lengths[:, 1::2] = mask_lengths - lengths[:, 0::2] = padding_lengths + # get padding lengths by randomly placing dividers on the line of length "padding_tot_rlen" + # P is the number of padding regions for each pair of sequences. + P = M + 1 + # rpositions means positions expressed in relative length, i.e. normalized so that + # seq_len is 1. + padding_rpositions = torch.rand(B, P - 1, device=device) * padding_tot_rlen + padding_rpositions = padding_rpositions.sort(dim=1)[0] + zero = torch.zeros(B, 1, device=device) + padding_rpositions = torch.cat((zero, padding_rpositions, padding_tot_rlen), dim=1) + padding_rlengths = padding_rpositions[:, 1:] - padding_rpositions[:, :-1] + + # 'rlengths' are the normalized lengths of the padding regions and the masks. + rlengths = torch.empty(B, 2 * M + 1, device=device) + rlengths[:, 1::2] = mask_rlengths + rlengths[:, 0::2] = padding_rlengths + + # lengths is the lengths of the masks and padding regions, converted to absolute + # length. We have to normalize before multiplying by seq_len because of the .clamp() + # operation above-- not all sequences will sum to one. + lengths = (rlengths / rlengths.sum(dim=1, keepdim=True)) * seq_len positions = torch.cumsum(lengths, dim=1) # last element of 'positions' should be seq_len @@ -190,7 +203,20 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ # positions does not have a leading zero, cumsum is inclusive; but do not treat final `seq_len` as a mask start position. mask_starts = positions[:, 0:-1:2] mask_ends = positions[:, 1::2] - assert mask_starts.shape == (batch_size, num_masks) and mask_ends.shape == (batch_size, num_masks) + assert mask_starts.shape == (B, M) and mask_ends.shape == (B, M) + + + # letting A,B be randomly 0 or 1 avoids any overall bias towards the start or end of the + # sequence in case the batch size is odd. + A = random.randint(0, 1) + B = (A + 1) % 2 + mask_starts1 = mask_starts[:, A::2] + mask_ends1 = mask_ends[:, A::2] + mask_starts2 = mask_starts[:, B::2] + mask_ends2 = mask_ends[:, B::2] + + mask_starts = torch.cat((mask_starts1, mask_starts2), dim=0)[:batch_size] + mask_ends = torch.cat((mask_ends1, mask_ends2), dim=0)[:batch_size] return mask_starts, mask_ends @@ -213,11 +239,11 @@ def load_state_dict(self, state_dict: Dict[str, Any]): def _test_exp_augment(): for n in [ 0, 1 ]: #device = 'cuda' - B, T, F = 300, 600, 80 + B, T, F = 301, 600, 80 device = 'cpu' if n == 0: - exp_augment = ExpAugment(p=1.0) #, max_frame_mask_size=2.0, max_frame_mask_fraction=0.02) + exp_augment = ExpAugment() #, max_frame_mask_size=2.0, max_frame_mask_fraction=0.02) else: from lhotse.dataset import SpecAugment time_mask_ratio = 3.5 @@ -234,7 +260,7 @@ def _test_exp_augment(): num_feature_masks=2, frames_mask_size=100, max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 - p=1.0, + p=0.9, ) supervision_segments = torch.stack(( torch.arange(B, device=device), # sequence_idx From 1ce2a50e68319df2504f227c853858395086d6c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Jul 2025 19:12:30 +0800 Subject: [PATCH 0391/1191] Add soft link so multi-job runs can continue. --- egs/librispeech/ASR/zipformer/speech_recognition.py | 1 + 1 file changed, 1 insertion(+) create mode 120000 egs/librispeech/ASR/zipformer/speech_recognition.py diff --git a/egs/librispeech/ASR/zipformer/speech_recognition.py b/egs/librispeech/ASR/zipformer/speech_recognition.py new file mode 120000 index 0000000000..cb33e61085 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/speech_recognition.py @@ -0,0 +1 @@ +../zapformer/speech_recognition.py \ No newline at end of file From 66499ace8c5b35a9e42d0b99cdd7c5581567b074 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Jul 2025 21:42:18 +0800 Subject: [PATCH 0392/1191] Increase max_frame_mask_fraction from 0.525 to 0.675. --- icefall/exp_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index a2adaaf8c6..92f8f30ae3 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -14,7 +14,7 @@ def __init__( self, max_feature_mask_fraction: float = 0.675, # max fraction that can possibly be masked num_feature_masks: int = 2, - max_frame_mask_fraction: float = 0.525, + max_frame_mask_fraction: float = 0.675, max_frame_mask_size: float = 100, # max size in frames of temporal masks. p=0.9, # probability of doing augmentation ): From 6867543239e12c3e412531c5d7def45f416cde79 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 7 Jul 2025 10:59:06 +0800 Subject: [PATCH 0393/1191] Edit zapformer/model.py and zapformer/train.py to use ExpAugment not SpecAugment. --- egs/librispeech/ASR/zapformer/decode.py | 1090 ++++++++++++++++++++++- egs/librispeech/ASR/zapformer/model.py | 17 +- egs/librispeech/ASR/zapformer/train.py | 47 +- 3 files changed, 1110 insertions(+), 44 deletions(-) mode change 120000 => 100755 egs/librispeech/ASR/zapformer/decode.py diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py deleted file mode 120000 index 82581c6d36..0000000000 --- a/egs/librispeech/ASR/zapformer/decode.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py new file mode 100755 index 0000000000..504d1d94d2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -0,0 +1,1089 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens, _predict_loss = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 56f744d5ea..d2d1e3f34f 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -23,7 +23,6 @@ import torch.nn as nn from torch import Tensor from encoder_interface import EncoderInterface -from lhotse.dataset import SpecAugment from scaling import ScaledLinear, convert_num_channels from icefall.utils import add_sos, make_pad_mask, time_warp @@ -367,7 +366,7 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, supervision_segments: Optional[torch.Tensor] = None, time_warp_factor: Optional[int] = 80, num_copies: int = 1, @@ -392,16 +391,16 @@ def forward( The scale to smooth the loss with lm (output of predictor network) part spec_augment: - The SpecAugment instance that returns time masks, - used only if use_cr_ctc is True. + The SpecAugment instance, or similar/compatible object, that masks + log-mel features. supervision_segments: An int tensor of shape ``(S, 3)``. ``S`` is the number of - supervision segments that exist in ``features``. - Used only if use_cr_ctc is True. + supervision segments that exist in ``features``. Used only for + time-warping, if num_copies > 1. time_warp_factor: Parameter for the time warping; larger values mean more warping. Set to ``None``, or less than ``1``, to disable. - Used only if use_cr_ctc is True. + Used only if num_copies > 1, corresponds to training mode. num_copies: the number of copies of the same data that are in the batch, e.g. 1, 2 or 3; affects CRCTC, spec-augment, etc. @@ -427,8 +426,8 @@ def forward( if num_copies > 1: assert num_copies == 3 # for now. - # will do SpecAugment. - assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # will do SpecAugment or similar. + assert spec_augment is not None and getattr(spec_augment, 'time_warp_factor', -1) < 0 (batch_size, seq_len, num_channels) = x.shape B = batch_size // num_copies diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ce2f9507c9..de6850105a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -71,10 +71,9 @@ from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut -from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model import AsrModel # TODO: change to model +from model import AsrModel from optim import Sched3, TransformedAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling @@ -94,6 +93,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error +from icefall.exp_augment import ExpAugment # using this, not lhotse's version of SpecAugment from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -542,13 +542,6 @@ def get_parser(): help="Prediction of random k-means after widest zipformer layer" ) - parser.add_argument( - "--time-mask-ratio", - type=float, - default=2.5, - help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", - ) - parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -820,24 +813,6 @@ def get_model(params: AttributeDict) -> nn.Module: return model -def get_spec_augment(params: AttributeDict) -> SpecAugment: - num_frame_masks = int(10 * params.time_mask_ratio) - max_frames_mask_fraction = 0.15 * params.time_mask_ratio - logging.info( - f"num_frame_masks: {num_frame_masks}, " - f"max_frames_mask_fraction: {max_frames_mask_fraction}" - ) - spec_augment = SpecAugment( - time_warp_factor=0, # Do time warping in model.py - num_frame_masks=num_frame_masks, # default: 10 - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 - ) - return spec_augment - - def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -960,7 +935,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -978,7 +953,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. spec_augment: - The SpecAugment instance, used for training + The SpecAugment instance (or similar object), used for training """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -1043,13 +1018,15 @@ def warmup_schedule(scale, initial_factor): if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss - loss += params.cr_loss_scale * cr_loss + if num_copies > 1: + loss += params.cr_loss_scale * cr_loss reconstruction_loss_scale = params.reconstruction_loss_scale loss += reconstruction_loss_scale * reconstruction_loss - loss += params.predict_loss_scale * predict_loss + if num_copies > 1: + loss += params.predict_loss_scale * predict_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -1071,8 +1048,10 @@ def warmup_schedule(scale, initial_factor): info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() - info["cr_loss"] = cr_loss.detach().cpu().item() - info["predict_loss"] = predict_loss.detach().cpu().item() + if num_copies > 1: + info["cr_loss"] = cr_loss.detach().cpu().item() + if num_copies > 1: + info["predict_loss"] = predict_loss.detach().cpu().item() info["recon_loss"] = reconstruction_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -1408,7 +1387,7 @@ def run(rank, world_size, args): assert params.use_ctc # for now, require CTC, we may remove this requirement later. - spec_augment = get_spec_augment(params) + spec_augment = ExpAugment() assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None From 838ca3f509a20230f98059f4a8b044ec03fbdb8d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 7 Jul 2025 11:19:43 +0800 Subject: [PATCH 0394/1191] Bug fixes --- egs/librispeech/ASR/zapformer/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index de6850105a..69e070eedc 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -93,7 +93,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from icefall.exp_augment import ExpAugment # using this, not lhotse's version of SpecAugment +from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -953,7 +953,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. spec_augment: - The SpecAugment instance (or similar object), used for training + The nn.Module instance (or similar object), used for training """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -973,7 +973,7 @@ def compute_loss( if num_copies > 1: assert model.training - # will need the following for time-warping in SpecAugment. + # will need the following for time-warping in nn.Module. supervision_intervals = batch["supervisions"] supervision_segments = torch.stack( [ @@ -1102,7 +1102,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1130,7 +1130,7 @@ def train_one_epoch( scaler: The scaler used for mix precision training. spec_augment: - The SpecAugment instance used for CR-CTC. + The SpecAugment or similar instance used for CR-CTC. model_avg: The stored model averaged from the start of training. tb_writer: @@ -1608,7 +1608,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, - spec_augment: Optional[SpecAugment] = None, + spec_augment: Optional[nn.Module] = None, ): from lhotse.dataset import find_pessimistic_batches From 06ae5bf699473da3d08aa9bec94b944a311fce7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Jul 2025 12:10:18 +0800 Subject: [PATCH 0395/1191] Merge chnges from 871 and 875, so we do only halfway-normalization over the batch dim, not full normalization. --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 82492d6f68..0e42fab0be 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -544,8 +544,12 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, return torch.tensor(0.0, device=x.device) def mean_and_variance_norm(x): - mean = x.mean(dim=list(range(x.ndim-1))) + mean_dims = list([ i for i in range(x.ndim-1) if i != batch_dim ]) + mean = x.mean(dim=mean_dims, keepdim=True) x = x - mean + # go halfway towards also normalizing across sequences, so + # it will keep half of the within-sequence normalization. + x = x - (0.5 * mean.mean(dim=batch_dim, keepdim=True)) eps = 1.0e-08 stddev = ((x ** 2).mean(dim=list(range(x.ndim-1))) + eps).sqrt() x = x / stddev From 1d0394abe0317c668bec41a8cc4fbbff4ff2b45e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Jul 2025 13:22:04 +0800 Subject: [PATCH 0396/1191] Implement volume normalization in reconstruction loss. --- egs/librispeech/ASR/zapformer/model.py | 75 ++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index d2d1e3f34f..26623d8183 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -127,7 +127,6 @@ def __init__( self.reconstruction_proj = ScaledLinear( encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) - self.reconstruction_loss = torch.nn.SmoothL1Loss(reduction='none', beta=1.0) def forward_encoder( @@ -556,15 +555,71 @@ def forward_reconstruction_loss(self, # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly # helps to down-weight the effect of very silent silences. - loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, - reduction='none', beta=1.0) - - # masking. if it's different from the next item on both the frequency dim - # and the time dim, it means we are in neither a time masked nor a frequency masked - # position. - mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), - log_mels != torch.roll(log_mels, 1, dims=1)) - loss = loss * mask.to(loss.dtype) + #loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, + # reduction='none', beta=1.0) + # this way of applying the padding mask is not really ideal in terms of normalization, + # it will cause us to under-normalize a bit. + diff = log_mels * pad_mask - pred_mels * pad_mask + # mean over sequence and mel-bin dims but not batch. + loss = smooth_l1_loss_mod(diff, beta=1.0, norm_dims=(1, 2)) + + # removing the masking logic since we now use the no-specaug reference sequence. + ## masking. if it's different from the next item on both the frequency dim + ## and the time dim, it means we are in neither a time masked nor a frequency masked + ## position. + #mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), + # log_mels != torch.roll(log_mels, 1, dims=1)) + #loss = loss * mask.to(loss.dtype) loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. return loss + + + +def smooth_l1_loss_mod(diffs: Tensor, beta: float = 1.0, + norm_dims: Optional[Tuple[int]] = None): + """ + This is similar to : + loss = torch.nn.SmoothL1Loss(reduction='none', beta=beta) + loss(a, b) is similar to smooth_l1_loss_mod(a - b), + except that it does an optional normalization step that involves + subtracting a mean computed over 'norm_dims'. + """ + assert beta > 0 + # torch.nn.SmoothL1Loss(reduction='none', beta=beta) is: + # l_n = 0.5 * (diff^2 / beta) if |diff| < beta + # else: |diff| - 0.5 / beta + diffs_abs = diffs.abs() + l2_loss = (0.5 / beta) * (diffs ** 2) + l1_loss = diffs.abs() - (0.5 * beta) + # 'scale' is a loss scale such that if we multiply l2_loss by it, + # we get the final loss. + scale = l1_loss.clamp(min=0.5 * beta) / l2_loss.clamp(min=0.5 * beta) + diffs_scaled = scale.sqrt() * diffs + # ok, now we can treat the loss as (0.5 / beta) * diffs_scaled ** 2 + if norm_dims: + diffs_scaled = diffs_scaled - diffs_scaled.mean(dim=norm_dims, keepdim=True) + + loss = (0.5 / beta) * (diffs_scaled ** 2) + return loss + + + +def _test_smooth_l1_loss_mod(): + a = torch.randn(2, 50) + b = torch.randn(2, 50) + + beta = 2.0 + loss = torch.nn.SmoothL1Loss(reduction='none', beta=beta) + loss1 = loss(a, b) + loss2 = smooth_l1_loss_mod(a - b, beta=beta) + #print(f"loss1={loss1}, loss2={loss2}") + assert torch.allclose(loss1, loss2, atol=0.001) + + loss2_norm = smooth_l1_loss_mod(a - b, beta=beta, norm_dims=(1,)) + print(f"loss2-mean={loss2.mean()}, loss2_norm-mean={loss2_norm.mean()}") + assert loss2_norm.mean() <= loss2.mean() + + +if __name__ == '__main__': + _test_smooth_l1_loss_mod() From abdc2700f2f3a7abaf778a13faef3aaf30a2348d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Jul 2025 13:22:12 +0800 Subject: [PATCH 0397/1191] Implement volume normalization in reconstruction loss. --- egs/librispeech/ASR/zapformer/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 26623d8183..8ea8b1edde 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -561,6 +561,9 @@ def forward_reconstruction_loss(self, # it will cause us to under-normalize a bit. diff = log_mels * pad_mask - pred_mels * pad_mask # mean over sequence and mel-bin dims but not batch. + # this smooth_l1_loss_mod is intended to accomplish volume normalization at the + # sequence level, i.e. in case the differently-augmented signals have a difference in volume, + # which could happen due to musan augmentation. loss = smooth_l1_loss_mod(diff, beta=1.0, norm_dims=(1, 2)) # removing the masking logic since we now use the no-specaug reference sequence. From 0b34bc9751622ef2770469ad5cb2fab065a6cce3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Jul 2025 15:00:10 +0800 Subject: [PATCH 0398/1191] Bug fixes in modified smoothl1loss --- egs/librispeech/ASR/zapformer/model.py | 29 ++++++++++++++------------ 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 8ea8b1edde..d3a067ec09 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -589,28 +589,31 @@ def smooth_l1_loss_mod(diffs: Tensor, beta: float = 1.0, subtracting a mean computed over 'norm_dims'. """ assert beta > 0 - # torch.nn.SmoothL1Loss(reduction='none', beta=beta) is: - # l_n = 0.5 * (diff^2 / beta) if |diff| < beta - # else: |diff| - 0.5 / beta - diffs_abs = diffs.abs() - l2_loss = (0.5 / beta) * (diffs ** 2) - l1_loss = diffs.abs() - (0.5 * beta) - # 'scale' is a loss scale such that if we multiply l2_loss by it, - # we get the final loss. - scale = l1_loss.clamp(min=0.5 * beta) / l2_loss.clamp(min=0.5 * beta) - diffs_scaled = scale.sqrt() * diffs + def get_scale(diffs): + # torch.nn.SmoothL1Loss(reduction='none', beta=beta) is: + # l_n = 0.5 * (diff^2 / beta) if |diff| < beta + # else: |diff| - 0.5 / beta + diffs_abs = diffs.abs() + l2_loss = (0.5 / beta) * (diffs ** 2) + l1_loss = diffs.abs() - (0.5 * beta) + # 'scale' is a loss scale such that if we multiply l2_loss by it, + # we get the final loss. + scale = l1_loss.clamp(min=0.5 * beta) / l2_loss.clamp(min=0.5 * beta) + return scale.sqrt() # ok, now we can treat the loss as (0.5 / beta) * diffs_scaled ** 2 if norm_dims: - diffs_scaled = diffs_scaled - diffs_scaled.mean(dim=norm_dims, keepdim=True) + scale = get_scale(diffs) + offset = (scale * diffs).mean(dim=norm_dims, keepdim=True) / scale.mean(dim=norm_dims, keepdim=True) + diffs = diffs - offset - loss = (0.5 / beta) * (diffs_scaled ** 2) + loss = (0.5 / beta) * ((diffs * get_scale(diffs)) ** 2) return loss def _test_smooth_l1_loss_mod(): a = torch.randn(2, 50) - b = torch.randn(2, 50) + b = torch.randn(4, 50) + 10. * torch.randn(4, 1) beta = 2.0 loss = torch.nn.SmoothL1Loss(reduction='none', beta=beta) From a1a1f865611b8966ff6cbf29068c3ab3857f0748 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Jul 2025 15:00:43 +0800 Subject: [PATCH 0399/1191] Bug fixes in modified smoothl1loss --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index d3a067ec09..7573778c81 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -612,7 +612,7 @@ def get_scale(diffs): def _test_smooth_l1_loss_mod(): - a = torch.randn(2, 50) + a = torch.randn(4, 50) b = torch.randn(4, 50) + 10. * torch.randn(4, 1) beta = 2.0 From c9f947b7ed4a368e40f4f7bb356d1f8caf0fb209 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Jul 2025 11:02:37 +0800 Subject: [PATCH 0400/1191] Change max_frame_mask_fraction from .675 to .725 and max_frame_mask_size from 100 to 70. --- icefall/exp_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 92f8f30ae3..695ecd604a 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -14,8 +14,8 @@ def __init__( self, max_feature_mask_fraction: float = 0.675, # max fraction that can possibly be masked num_feature_masks: int = 2, - max_frame_mask_fraction: float = 0.675, - max_frame_mask_size: float = 100, # max size in frames of temporal masks. + max_frame_mask_fraction: float = 0.725, + max_frame_mask_size: float = 70, # max size in frames of temporal masks. p=0.9, # probability of doing augmentation ): super().__init__() From f61d232c3c9e5401b91647274875731a423843ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Jul 2025 12:27:19 +0800 Subject: [PATCH 0401/1191] Make 0.5 of the over-normalization be at sequence level and 1.0 at batch level, i.e. swapping the factors. --- egs/librispeech/ASR/zipformer/scaling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 0e42fab0be..53d9b39b15 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -546,10 +546,10 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, def mean_and_variance_norm(x): mean_dims = list([ i for i in range(x.ndim-1) if i != batch_dim ]) mean = x.mean(dim=mean_dims, keepdim=True) - x = x - mean - # go halfway towards also normalizing across sequences, so - # it will keep half of the within-sequence normalization. - x = x - (0.5 * mean.mean(dim=batch_dim, keepdim=True)) + # over-normalization, by, totally a factor of 1.5, of which 0.5 is at + # sequence level and 1.0 is at batch level. + x = x - 0.5 * mean + x = x - mean.mean(dim=batch_dim, keepdim=True) eps = 1.0e-08 stddev = ((x ** 2).mean(dim=list(range(x.ndim-1))) + eps).sqrt() x = x / stddev From 1e7d80b5c4d383d70170ac619cc1d2383393a4ca Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Jul 2025 19:45:45 +0800 Subject: [PATCH 0402/1191] Do over-normalization at only whole batch level, not sequence level. --- egs/librispeech/ASR/zipformer/scaling.py | 70 +++++++++++++++++------- 1 file changed, 49 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 53d9b39b15..e3ef6503ec 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -534,24 +534,19 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, - batch_dim: int, name: str, +def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, name: str, mask: Optional[Tensor]) -> Tensor: - batch_size = x.shape[batch_dim] + batch_size = x.shape[1] if batch_size % 2 != 0: assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." return torch.tensor(0.0, device=x.device) def mean_and_variance_norm(x): - mean_dims = list([ i for i in range(x.ndim-1) if i != batch_dim ]) - mean = x.mean(dim=mean_dims, keepdim=True) - # over-normalization, by, totally a factor of 1.5, of which 0.5 is at - # sequence level and 1.0 is at batch level. - x = x - 0.5 * mean - x = x - mean.mean(dim=batch_dim, keepdim=True) + mean = x.mean(dim=(0, 1), keepdim=True) # mean on sequence and batch dim + x = x - 1.5 * mean eps = 1.0e-08 - stddev = ((x ** 2).mean(dim=list(range(x.ndim-1))) + eps).sqrt() + stddev = ((x ** 2).mean(dim=(0,1)) + eps).sqrt() x = x / stddev return x @@ -565,9 +560,8 @@ def mean_and_variance_norm(x): indexes = torch.max(x_proj, dim=-1)[1] - indexes = torch.roll(indexes, batch_size // 2, batch_dim) # predict index of the other masked copy. - x_pred = predictor(x) - logprobs = x_pred.log_softmax(dim=-1) + indexes = torch.roll(indexes, batch_size // 2, 1) # predict index of the other masked copy. + logprobs = predictor(x) loss = -torch.gather(logprobs, dim=-1, index=indexes.unsqueeze(-1)) if random.random() < 0.002: @@ -578,11 +572,46 @@ def mean_and_variance_norm(x): # we also swap the mask over the two copies of the data; the mask goes with the thing that # is predicted, not the thing we predict it from.. the idea being that we don't want to ask # the model to predict masked portions of the time sequence. - mask = torch.roll(mask, batch_size // 2, batch_dim) + mask = torch.roll(mask, batch_size // 2, 1) loss = loss * mask.unsqueeze(-1) return loss.sum() # we reduce with sum in what we return. +class Predictor(nn.Module): + """ + A simple feedforward module used in PredictLoss to predict codebook entries derived from the other copy of the data's + embeddings. + """ + def __init__(self, + num_channels: int, + num_hidden: int, + codebook_size: int): + super().__init__() + self.in_proj = nn.Linear(num_channels, num_hidden) + self.self_mean_proj = nn.Linear(num_channels, num_hidden) + self.other_mean_proj = nn.Linear(num_channels, num_hidden) + self.activation = SwashR() + self.out_proj = nn.Linear(num_hidden, codebook_size) + + def forward(self, + x: Tensor): + """ + Args: + x: (seq_len, batch_size, num_channels), batch_size must be even. + Returns: + normalized codebook logprobs, dim: (seq_len, batch_size, codebook_size) + """ + (seq_len, batch_size, num_channels) = x.shape + assert batch_size % 2 == 0 + x_mean = x.mean(dim=0, keepdim=True) + # I am cautious about providing the other mean non-detached.. + x_mean_swapped = x_mean.detach().roll(batch_size // 2, 1) + x = self.in_proj(x) + self.self_mean_proj(x_mean) + self.other_mean_proj(x_mean_swapped) + x = self.activation(x) + x = self.out_proj(x) + x = x.log_softmax(dim=-1) + return x + class PredictLoss(nn.Module): """ Adds an auxiliary loss based on predicting the top-1 of randomized codebook @@ -592,7 +621,6 @@ class PredictLoss(nn.Module): """ def __init__(self, num_channels: int, - batch_dim: int = 0, codebook_size: int = 64): super().__init__() scale = num_channels ** -0.5 @@ -600,18 +628,18 @@ def __init__(self, scale * torch.randn(codebook_size, num_channels), persistent=True) num_hidden = max(1024, num_channels) - self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), - SwashR(), - nn.Linear(num_hidden, codebook_size)) - self.batch_dim = batch_dim + # num_channels * 2 because we also provide the sequence-level difference + # in means between the two copies, detached, to help it normalize + # for things like differences in frequency masks and volume. + self.predictor = Predictor(num_channels, num_hidden, codebook_size) self.name = None # will be set from training code def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - # x is of shape (..., num_channels); mask is of shape (...), i.e. + # x is of shape (seq_len, batch_size, num_channels); mask is of shape (seq_len, batch_size), i.e. # it matches x except is missing the last dim. return predict_loss(x, self.predictor, self.proj_weight, - self.batch_dim, self.name, mask) + self.name, mask) From 4858c1508b580e18326d8cdf1f9615055807fe05 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Jul 2025 19:54:47 +0800 Subject: [PATCH 0403/1191] Fix to last commit --- egs/librispeech/ASR/zipformer/scaling.py | 67 +++++++----------------- 1 file changed, 18 insertions(+), 49 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e3ef6503ec..95d689105d 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -534,19 +534,21 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, name: str, +def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, + batch_dim: int, name: str, mask: Optional[Tensor]) -> Tensor: - batch_size = x.shape[1] + # caution: now require input to be either (batch, seq, channel) or (seq, batch, channel) + batch_size = x.shape[batch_dim] if batch_size % 2 != 0: assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." return torch.tensor(0.0, device=x.device) def mean_and_variance_norm(x): - mean = x.mean(dim=(0, 1), keepdim=True) # mean on sequence and batch dim - x = x - 1.5 * mean + mean = x.mean(dim=(0,1), keepdim=True) + x = x - 1.5 * mean # over-normalization. eps = 1.0e-08 - stddev = ((x ** 2).mean(dim=(0,1)) + eps).sqrt() + stddev = ((x ** 2).mean(dim=(0, 1)) + eps).sqrt() x = x / stddev return x @@ -560,8 +562,9 @@ def mean_and_variance_norm(x): indexes = torch.max(x_proj, dim=-1)[1] - indexes = torch.roll(indexes, batch_size // 2, 1) # predict index of the other masked copy. - logprobs = predictor(x) + indexes = torch.roll(indexes, batch_size // 2, batch_dim) # predict index of the other masked copy. + x_pred = predictor(x) + logprobs = x_pred.log_softmax(dim=-1) loss = -torch.gather(logprobs, dim=-1, index=indexes.unsqueeze(-1)) if random.random() < 0.002: @@ -572,46 +575,11 @@ def mean_and_variance_norm(x): # we also swap the mask over the two copies of the data; the mask goes with the thing that # is predicted, not the thing we predict it from.. the idea being that we don't want to ask # the model to predict masked portions of the time sequence. - mask = torch.roll(mask, batch_size // 2, 1) + mask = torch.roll(mask, batch_size // 2, batch_dim) loss = loss * mask.unsqueeze(-1) return loss.sum() # we reduce with sum in what we return. -class Predictor(nn.Module): - """ - A simple feedforward module used in PredictLoss to predict codebook entries derived from the other copy of the data's - embeddings. - """ - def __init__(self, - num_channels: int, - num_hidden: int, - codebook_size: int): - super().__init__() - self.in_proj = nn.Linear(num_channels, num_hidden) - self.self_mean_proj = nn.Linear(num_channels, num_hidden) - self.other_mean_proj = nn.Linear(num_channels, num_hidden) - self.activation = SwashR() - self.out_proj = nn.Linear(num_hidden, codebook_size) - - def forward(self, - x: Tensor): - """ - Args: - x: (seq_len, batch_size, num_channels), batch_size must be even. - Returns: - normalized codebook logprobs, dim: (seq_len, batch_size, codebook_size) - """ - (seq_len, batch_size, num_channels) = x.shape - assert batch_size % 2 == 0 - x_mean = x.mean(dim=0, keepdim=True) - # I am cautious about providing the other mean non-detached.. - x_mean_swapped = x_mean.detach().roll(batch_size // 2, 1) - x = self.in_proj(x) + self.self_mean_proj(x_mean) + self.other_mean_proj(x_mean_swapped) - x = self.activation(x) - x = self.out_proj(x) - x = x.log_softmax(dim=-1) - return x - class PredictLoss(nn.Module): """ Adds an auxiliary loss based on predicting the top-1 of randomized codebook @@ -621,6 +589,7 @@ class PredictLoss(nn.Module): """ def __init__(self, num_channels: int, + batch_dim: int = 0, codebook_size: int = 64): super().__init__() scale = num_channels ** -0.5 @@ -628,18 +597,18 @@ def __init__(self, scale * torch.randn(codebook_size, num_channels), persistent=True) num_hidden = max(1024, num_channels) - # num_channels * 2 because we also provide the sequence-level difference - # in means between the two copies, detached, to help it normalize - # for things like differences in frequency masks and volume. - self.predictor = Predictor(num_channels, num_hidden, codebook_size) + self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), + SwashR(), + nn.Linear(num_hidden, codebook_size)) + self.batch_dim = batch_dim self.name = None # will be set from training code def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - # x is of shape (seq_len, batch_size, num_channels); mask is of shape (seq_len, batch_size), i.e. + # x is of shape (..., num_channels); mask is of shape (...), i.e. # it matches x except is missing the last dim. return predict_loss(x, self.predictor, self.proj_weight, - self.name, mask) + self.batch_dim, self.name, mask) From d1a9224b2c56bd6edf9258380e27990e2a53f96f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Jul 2025 14:35:29 +0800 Subject: [PATCH 0404/1191] Implement convolutional predictor in PredictLoss. --- egs/librispeech/ASR/zipformer/scaling.py | 81 ++++++++++++++++++---- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 95d689105d..cdb7c142f6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -535,10 +535,10 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, - batch_dim: int, name: str, + name: str, mask: Optional[Tensor]) -> Tensor: - # caution: now require input to be either (batch, seq, channel) or (seq, batch, channel) - batch_size = x.shape[batch_dim] + # caution: now require input to be (seq, batch, channel) + batch_size = x.shape[1] if batch_size % 2 != 0: assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." @@ -562,7 +562,7 @@ def mean_and_variance_norm(x): indexes = torch.max(x_proj, dim=-1)[1] - indexes = torch.roll(indexes, batch_size // 2, batch_dim) # predict index of the other masked copy. + indexes = torch.roll(indexes, batch_size // 2, 1) x_pred = predictor(x) logprobs = x_pred.log_softmax(dim=-1) loss = -torch.gather(logprobs, dim=-1, index=indexes.unsqueeze(-1)) @@ -575,11 +575,70 @@ def mean_and_variance_norm(x): # we also swap the mask over the two copies of the data; the mask goes with the thing that # is predicted, not the thing we predict it from.. the idea being that we don't want to ask # the model to predict masked portions of the time sequence. - mask = torch.roll(mask, batch_size // 2, batch_dim) + mask = torch.roll(mask, batch_size // 2, 1) loss = loss * mask.unsqueeze(-1) return loss.sum() # we reduce with sum in what we return. + +class PredictorConvModule(nn.Module): + """A convolution module with a residual connecction, modified from ConvolutionModule in Zipformer2, that is used as + the predictor network in class Predictor. The input format is (seq, batch, channels). + + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + kernel_size: int, + out_channels: int, + ) -> None: + """Construct a ConvolutionModule object.""" + super().__init__() + assert (kernel_size - 1) % 2 == 0 + + self.in_proj = nn.Linear( + channels, + hidden_channels, + ) + + self.depthwise_conv = nn.Conv1d( + in_channels=hidden_channels, + out_channels=hidden_channels, + groups=hidden_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + + self.out_proj = ActivationDropoutAndLinear( + hidden_channels, + out_channels, + activation="SwashR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + ) -> Tensor: + x = self.in_proj(x) # (time, batch, 2*channels) + x = x.permute(1, 2, 0) # (#batch, channels, time). + x = self.depthwise_conv(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + x = self.out_proj(x) # includes activation. + return x + + + class PredictLoss(nn.Module): """ Adds an auxiliary loss based on predicting the top-1 of randomized codebook @@ -589,18 +648,16 @@ class PredictLoss(nn.Module): """ def __init__(self, num_channels: int, - batch_dim: int = 0, codebook_size: int = 64): super().__init__() scale = num_channels ** -0.5 self.register_buffer('proj_weight', scale * torch.randn(codebook_size, num_channels), persistent=True) - num_hidden = max(1024, num_channels) - self.predictor = nn.Sequential(nn.Linear(num_channels, num_hidden), - SwashR(), - nn.Linear(num_hidden, codebook_size)) - self.batch_dim = batch_dim + num_hidden = max(512, num_channels) + kernel_size = 7 + self.predictor = PredictorConvModule(num_channels, num_hidden, kernel_size, codebook_size) + self.name = None # will be set from training code def forward(self, @@ -608,7 +665,7 @@ def forward(self, # x is of shape (..., num_channels); mask is of shape (...), i.e. # it matches x except is missing the last dim. return predict_loss(x, self.predictor, self.proj_weight, - self.batch_dim, self.name, mask) + self.name, mask) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8111379b3d..8b7bb6b010 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -756,7 +756,7 @@ def __init__( grad_scale=0.025, ) - self.predict_loss = PredictLoss(dim, batch_dim=1) + self.predict_loss = PredictLoss(dim) def forward( From 6565156b1aa2f437c9a8fcae3ee339aa53ed61e4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Jul 2025 15:56:13 +0800 Subject: [PATCH 0405/1191] Fix to comment. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index cdb7c142f6..ba3a6cf944 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -662,8 +662,8 @@ def __init__(self, def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - # x is of shape (..., num_channels); mask is of shape (...), i.e. - # it matches x except is missing the last dim. + # x is of shape (seq_len, batch_size, num_channels); mask is of shape + # (seq_len, batch_size), with True for *non*-masked positions. return predict_loss(x, self.predictor, self.proj_weight, self.name, mask) From 5bfa3e9e2c8920d07defc7ac3f67fbdc85f16bfe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Jul 2025 16:01:12 +0800 Subject: [PATCH 0406/1191] Revert num_hidden from 512 to 1024 and add bypass proj. --- egs/librispeech/ASR/zipformer/scaling.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ba3a6cf944..39f97b781d 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -610,6 +610,11 @@ def __init__( hidden_channels, ) + self.bypass_proj = nn.Linear( + channels, + out_channels, + ) + self.depthwise_conv = nn.Conv1d( in_channels=hidden_channels, out_channels=hidden_channels, @@ -630,11 +635,12 @@ def forward( self, x: Tensor, ) -> Tensor: + bypass = self.bypass_proj(x) x = self.in_proj(x) # (time, batch, 2*channels) x = x.permute(1, 2, 0) # (#batch, channels, time). x = self.depthwise_conv(x) x = x.permute(2, 0, 1) # (time, batch, channels) - x = self.out_proj(x) # includes activation. + x = bypass + self.out_proj(x) # includes activation. return x @@ -654,7 +660,7 @@ def __init__(self, self.register_buffer('proj_weight', scale * torch.randn(codebook_size, num_channels), persistent=True) - num_hidden = max(512, num_channels) + num_hidden = max(1024, num_channels) kernel_size = 7 self.predictor = PredictorConvModule(num_channels, num_hidden, kernel_size, codebook_size) From 7d7892ee7cc061fe8618baf2c76116e3a10c9ac0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 11:14:17 +0800 Subject: [PATCH 0407/1191] Replace LogSoftmax of CTC with SquareLogSoftmax. --- egs/librispeech/ASR/zapformer/model.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 7573778c81..f006fb8500 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -116,7 +116,7 @@ def __init__( self.ctc_output = nn.Sequential( nn.Dropout(p=0.1), ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), - nn.LogSoftmax(dim=-1), + SquareLogSoftmax(dim=-1), ) self.use_attention_decoder = use_attention_decoder @@ -578,6 +578,24 @@ def forward_reconstruction_loss(self, return loss +class SquareLogSoftmax(nn.Module): + def __init__(self, dim: int = -1, eps: float = 1.0e-05): + super().__init__() + self.dim = dim + self.eps = eps + + + def forward(self, x: Tensor): + dim = self.dim + eps = self.eps + norm = (x ** 2).sum(dim=dim, keepdim=True).clamp(min=eps) ** -0.5 + x = x.clamp(min=eps) * norm + # x**2 is the probability, we return the log of that which is 2 * log(x). The probs x**2 cannot + # sum up to more than 1, because of the normalization above. (The sum may be less than 1, if some + # x values are negative.) This ignores clamping to eps though. + return 2 * x.log() + + def smooth_l1_loss_mod(diffs: Tensor, beta: float = 1.0, norm_dims: Optional[Tuple[int]] = None): From 34de050256f44bcc609fa3bf6ba8678b63717e5c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 11:36:40 +0800 Subject: [PATCH 0408/1191] Add SquareLogSoftmax to other things interpreted as logprobs. --- egs/librispeech/ASR/zapformer/model.py | 30 ++++++------------------ egs/librispeech/ASR/zipformer/joiner.py | 5 ++-- egs/librispeech/ASR/zipformer/scaling.py | 19 +++++++++++++++ 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index f006fb8500..34578a3fb2 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from torch import Tensor from encoder_interface import EncoderInterface -from scaling import ScaledLinear, convert_num_channels +from scaling import ScaledLinear, convert_num_channels, SquareLogSoftmax from icefall.utils import add_sos, make_pad_mask, time_warp @@ -99,11 +99,13 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_scale=0.1, + self.simple_am_proj = nn.Sequential( + ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + SquareLogSoftmax(dim=-1), ) - self.simple_lm_proj = ScaledLinear( - decoder_dim, vocab_size, initial_scale=0.1, + self.simple_lm_proj = nn.Sequential( + ScaledLinear(decoder_dim, vocab_size, initial_scale=0.1), + SquareLogSoftmax(dim=-1), ) else: @@ -578,24 +580,6 @@ def forward_reconstruction_loss(self, return loss -class SquareLogSoftmax(nn.Module): - def __init__(self, dim: int = -1, eps: float = 1.0e-05): - super().__init__() - self.dim = dim - self.eps = eps - - - def forward(self, x: Tensor): - dim = self.dim - eps = self.eps - norm = (x ** 2).sum(dim=dim, keepdim=True).clamp(min=eps) ** -0.5 - x = x.clamp(min=eps) * norm - # x**2 is the probability, we return the log of that which is 2 * log(x). The probs x**2 cannot - # sum up to more than 1, because of the normalization above. (The sum may be less than 1, if some - # x values are negative.) This ignores clamping to eps though. - return 2 * x.log() - - def smooth_l1_loss_mod(diffs: Tensor, beta: float = 1.0, norm_dims: Optional[Tuple[int]] = None): diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index 0406efe834..76ce229c13 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from scaling import ScaledLinear +from scaling import ScaledLinear, SquareLogSoftmax class Joiner(nn.Module): @@ -32,6 +32,7 @@ def __init__( self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) self.output_linear = nn.Linear(joiner_dim, vocab_size) + self.output_log_softmax = SquareLogSoftmax(dim=-1) def forward( self, @@ -62,6 +63,6 @@ def forward( else: logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) + logit = self.output_log_softmax(self.output_linear(torch.tanh(logit))) return logit diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 39f97b781d..3808f2230e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1470,6 +1470,24 @@ def forward(self, x: Tensor) -> Tensor: return swashr_compiled(x) +class SquareLogSoftmax(nn.Module): + def __init__(self, dim: int = -1, eps: float = 1.0e-05): + super().__init__() + self.dim = dim + self.eps = eps + + + def forward(self, x: Tensor): + dim = self.dim + eps = self.eps + norm = (x ** 2).sum(dim=dim, keepdim=True).clamp(min=eps) ** -0.5 + x = x.clamp(min=eps) * norm + # x**2 is the probability, we return the log of that which is 2 * log(x). The probs x**2 cannot + # sum up to more than 1, because of the normalization above. (The sum may be less than 1, if some + # x values are negative.) This ignores clamping to eps though. + return 2 * x.log() + + class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod @@ -1545,6 +1563,7 @@ def backward(ctx, ans_grad: Tensor): return x_deriv, weight_deriv, bias_deriv, None, None, None + class ActivationDropoutAndLinear(torch.nn.Module): """ This merges an activation function followed by dropout and then a nn.Linear module; From ead892739aac5f3f8a6624bea6ba3d6836b6a430 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 12:11:31 +0800 Subject: [PATCH 0409/1191] Remove special treatment of negative values. --- .../ASR/zapformer_denoise/asr_datamodule.py | 448 ++++++ .../ASR/zapformer_denoise/decode.py | 537 +++++++ .../zapformer_denoise/decode_gigaspeech.py | 1 + .../zapformer_denoise/encoder_interface.py | 1 + .../ASR/zapformer_denoise/export-onnx-ctc.py | 1 + .../export-onnx-streaming-ctc.py | 1 + .../export-onnx-streaming.py | 1 + .../ASR/zapformer_denoise/export-onnx.py | 1 + .../ASR/zapformer_denoise/export.py | 1 + .../ASR/zapformer_denoise/finetune.py | 1 + .../generate_averaged_model.py | 1 + .../ASR/zapformer_denoise/label_smoothing.py | 1 + .../ASR/zapformer_denoise/model.py | 388 +++++ .../ASR/zapformer_denoise/optim.py | 1 + .../ASR/zapformer_denoise/pretrained.py | 1 + .../ASR/zapformer_denoise/scaling.py | 1 + .../zapformer_denoise/speech_recognition.py | 229 +++ .../ASR/zapformer_denoise/subsampling.py | 297 ++++ .../ASR/zapformer_denoise/test_scaling.py | 1 + .../ASR/zapformer_denoise/train.py | 1378 +++++++++++++++++ .../ASR/zapformer_denoise/zapformer.py | 1344 ++++++++++++++++ egs/librispeech/ASR/zipformer/scaling.py | 11 +- 22 files changed, 4640 insertions(+), 6 deletions(-) create mode 100755 egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py create mode 100755 egs/librispeech/ASR/zapformer_denoise/decode.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/encoder_interface.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/export.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/finetune.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/label_smoothing.py create mode 100755 egs/librispeech/ASR/zapformer_denoise/model.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/optim.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/pretrained.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/scaling.py create mode 100755 egs/librispeech/ASR/zapformer_denoise/speech_recognition.py create mode 100644 egs/librispeech/ASR/zapformer_denoise/subsampling.py create mode 120000 egs/librispeech/ASR/zapformer_denoise/test_scaling.py create mode 100755 egs/librispeech/ASR/zapformer_denoise/train.py create mode 100644 egs/librispeech/ASR/zapformer_denoise/zapformer.py diff --git a/egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py b/egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py new file mode 100755 index 0000000000..09513afbe0 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py @@ -0,0 +1,448 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + K2SpeechRecognitionDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=[], + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zapformer_denoise/decode.py b/egs/librispeech/ASR/zapformer_denoise/decode.py new file mode 100755 index 0000000000..dedf092b82 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/decode.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--num-steps", + type=int, + default=8, + help="""The number of time-steps in denoising decoding.""" + ) + + parser.add_argument( + "--eps", + type=float, + default=1.0e-04, + help="""The t value that we start from with pure noise.""" + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> List[List[str]]: + """Decode one batch and return the result as a list of sentences + (each sentence is a list of words). + + Args: + params: + The return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + The return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result as a list of list of strings (words), i.e. + a list of sentences. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + tokens = model.infer(feature, feature_lens, params.eps, params.num_steps) # list of lists of int + + hyps = [ sp.decode(t).split() for t in tokens ] # list of lists of str + + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> List[Tuple[str, List[str], List[str]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Returns list of tuples (cut_id, ref_transcript, hyp_transcript) + with types (str, List[str], List[str]). + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + + log_interval = 10 + + + results = [ ] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps = decode_one_batch( + params=params, + model=model, + sp=sp, + batch=batch, + ) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]] +): + """ + Save text produced by ASR. + """ + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results: List[Tuple[str, List[str], List[str], Tuple]], +): + """ + Save WER and per-utterance word alignments. + """ + + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}", results, enable_log=True + ) + logging.info(f"Wrote detailed error stats to {errs_filename}") + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print(f"{wer}", file=fd) + + s = f"\nFor {test_set_name}, WER is {wer}" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + params.res_dir = params.exp_dir / "decode" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + params.suffix += f"_{params.num_steps}step" + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results=results, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results=results, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py new file mode 120000 index 0000000000..63b0ef617b --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py @@ -0,0 +1 @@ +../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/encoder_interface.py b/egs/librispeech/ASR/zapformer_denoise/encoder_interface.py new file mode 120000 index 0000000000..aa5d0217a8 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py new file mode 120000 index 0000000000..dc14e93e75 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py new file mode 120000 index 0000000000..3baa2b673c --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py new file mode 120000 index 0000000000..d18cb9a9a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx.py new file mode 120000 index 0000000000..f343cf7027 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/export-onnx.py @@ -0,0 +1 @@ +../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export.py b/egs/librispeech/ASR/zapformer_denoise/export.py new file mode 120000 index 0000000000..1a126ab695 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/export.py @@ -0,0 +1 @@ +../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/finetune.py b/egs/librispeech/ASR/zapformer_denoise/finetune.py new file mode 120000 index 0000000000..0e9e7989b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/finetune.py @@ -0,0 +1 @@ +../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py b/egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py new file mode 120000 index 0000000000..b65513a058 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py @@ -0,0 +1 @@ +../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/label_smoothing.py b/egs/librispeech/ASR/zapformer_denoise/label_smoothing.py new file mode 120000 index 0000000000..3690afff9d --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/label_smoothing.py @@ -0,0 +1 @@ +../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/model.py b/egs/librispeech/ASR/zapformer_denoise/model.py new file mode 100755 index 0000000000..968575ecef --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/model.py @@ -0,0 +1,388 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, List + +import k2 +import torch +import logging +import torch.nn as nn +from torch import Tensor +from scaling import ScaledLinear, convert_num_channels, SwashR +import math +from icefall.utils import make_pad_mask, time_warp + + + +class DenoisingAsrModel(nn.Module): + def __init__( + self, + #speech_embed: nn.Module, + encoder: nn.Module, + encoder_dim: int, + text_embed_dim: int, + vocab_size: int, + time_embed_dim: int, + ): + """ + TODO + """ + super().__init__() + + self.speech_scale = 0.5 + self.encoder = encoder + self.encoder_dim = encoder_dim + + # s is the time value for the speech, 0 <= s <= 1. + # t is the time value for the symbols, 0 <= t <= 1. + self.time_embed_dim = time_embed_dim + self.st_embed = nn.Sequential( + nn.Linear(time_embed_dim * 2, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim) + ) + + # randomly initialize text embedding and do not train it. + text_embed_scale = 0.25 # this will ensure that later steps still "matter". + self.text_embed = FixedEmbedding(vocab_size, text_embed_dim, scale=text_embed_scale) + + self.text_in_proj = nn.Linear(text_embed_dim, encoder_dim) + self.text_out_proj = nn.Linear(encoder_dim, text_embed_dim) + + # for now just hardcode + speech_channels = 80 + speech_subsample = 4 + self.speech_out_proj = nn.Linear(encoder_dim, + speech_channels * speech_subsample) + + self.speech_in_proj = nn.Linear(speech_channels * speech_subsample, + encoder_dim) + + + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + y_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A Tensor of dtype long, indexed [utt][symbol], padded with symbol 0 + on the right. There is no BOS or EOS symbol. + + Returns: + Returns flow-matching loss values for symbols and speech respectively. + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + batch_size = x.shape[0] + assert x.shape[0] == x_lens.shape[0] == y.shape[0], (x.shape, x_lens.shape, y.shape) + + s = torch.rand(batch_size, device=x.device) # time-value for speech. + t = torch.rand(batch_size, device=x.device) # time-value for text. + + + st = self.st_embed(torch.cat((timestep_embedding(s, self.time_embed_dim), + timestep_embedding(t, self.time_embed_dim)), dim=1)) + # st: (batch_size, time_embed_dim) + + (batch_size, speech_seq_len, num_freqs) = x.shape + + device = x.device + x1 = x * self.speech_scale # scale log-mels by 0.1 to be better matched to normal distribution. + x0 = torch.randn_like(x1) + xs = (x1 * s[:, None, None]) + (x0 * (1 - s[:, None, None])) + # x1, x0, xs: (batch_size, seq_len, 80) + xV = x1 - x0 # xV means x velocity. (batch_size, speech_seq_len, 80) + + padding = (4 - (speech_seq_len % 4)) % 4 + xs = torch.nn.functional.pad(xs, (0, 0, 0, padding)) + xs = xs.reshape(batch_size, -1, 4 * num_freqs) + xs_embed = self.speech_in_proj(xs) + x_lens_embed = x_lens // 4 + + xs_embed = xs_embed.permute(1, 0, 2) # (embed_seq_len, batch_size, encoder_dim) + embed_seq_len = xs_embed.shape[0] + + with torch.amp.autocast('cuda', enabled=False): + y = randomly_pad_to_lengths(y, y_lens, torch.minimum(x_lens_embed, y_lens + y_lens // 4), embed_seq_len) + # now y: (batch_size, seq_len) + y1 = self.text_embed(y) + # now y1: (batch_size, seq_len, text_embed_dim) + y0 = torch.randn_like(y1) + yt = (y1 * t[:, None, None]) + (y0 * (1 - t[:, None, None])) + # yt: (batch_size, seq_len, text_embed_dim) + yt_embed = self.text_in_proj(yt).permute(1, 0, 2) # (embed_seq_len, batch_size, encoder_dim) + yV = y1 - y0 # yV means y velocity. (batch_size, embed_seq_len, text_embed_dim) + + encoder_in = xs_embed + yt_embed + + src_key_padding_mask = torch.arange(0, embed_seq_len, device=x.device) >= x_lens_embed.unsqueeze(-1) # (batch-size, max_x_len) + + encoder_out = self.encoder(encoder_in, st, x_lens_embed, src_key_padding_mask) + (embed_seq_len, batch_size, _encoder_dim) = encoder_out.shape + + xU = self.speech_out_proj(encoder_out) + xU = xU.permute(1, 0, 2).reshape(batch_size, embed_seq_len * 4, -1) + xU = xU[:, :speech_seq_len] # (batch_size, speech_seq_len, 80) + + # don't use x_mask in training, this will simplify inference. + # x_mask = (torch.arange(0, speech_seq_len, device=x.device) < x_lens.unsqueeze(-1)).unsqueeze(-1) + # x_mask: # (batch-size, speech_seq_len, 1). + + x_loss = ((xV - xU) ** 2).mean(dim=-1).sum() + + yU = self.text_out_proj(encoder_out) + yU = yU.permute(1, 0, 2) # (batch_size, embed_seq_len, text_embed_dim) + + #y_mask = torch.logical_not(src_key_padding_mask).unsqueeze(-1) + y_loss = ((yV - yU) ** 2).mean(dim=-1).sum() + + return x_loss, y_loss # speech_loss, text_loss + + + def infer( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + eps: float, + num_steps: int, + ) -> List[List[int]]: + """ + Does inference. Starting from random noise representing the text, does inference + for a number of steps and then converts the text representation to integers. + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + eps: + The 't' value to start inference from, e.g. 1.0e-04 + num_steps: + The number of inference steps to use. + + Returns: + Returns the inference result as a list of lists of symbols, with blanks (symbol zero) + removed. + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + batch_size = x.shape[0] + assert x.shape[0] == x_lens.shape[0] + + s = torch.ones(batch_size, device=x.device) # time-value for speech is 1.0 throughout. + xs = x * self.speech_scale # scale log-mels by 0.1 to be better matched to normal distribution. + # xs is the same as x1, because s == 1.0, in inference there is no noise on the speech. + (batch_size, speech_seq_len, num_freqs) = xs.shape + padding = (4 - (speech_seq_len % 4)) % 4 + xs = torch.nn.functional.pad(xs, (0, 0, 0, padding)) + xs = xs.reshape(batch_size, -1, 4 * num_freqs) + xs_embed = self.speech_in_proj(xs) + x_lens_embed = x_lens // 4 + xs_embed = xs_embed.permute(1, 0, 2) # (embed_seq_len, batch_size, encoder_dim) + (embed_seq_len, batch_size, encoder_dim) = xs_embed.shape + src_key_padding_mask = torch.arange(0, embed_seq_len, device=x.device) >= x_lens_embed.unsqueeze(-1) # (batch-size, max_x_len) + text_embed_dim = self.text_embed.weight.shape[1] + + delta_t = (1.0 - eps) / num_steps + + yt = torch.randn(embed_seq_len, batch_size, text_embed_dim, device=x.device) # start with noise at t ~ 0 + + for step in range(num_steps): + t = torch.full((batch_size,), eps + step * delta_t, device=x.device) # time-value for text. + st = self.st_embed(torch.cat((timestep_embedding(s, self.time_embed_dim), + timestep_embedding(t, self.time_embed_dim)), dim=1)) + # st: (batch_size, time_embed_dim) + + + yt_embed = self.text_in_proj(yt) # (embed_seq_len, batch_size, encoder_dim) + encoder_in = xs_embed + yt_embed + encoder_out = self.encoder(encoder_in, st, x_lens_embed, src_key_padding_mask) + yU = self.text_out_proj(encoder_out) + + yt = yt + yU * delta_t + + + yt = yt.permute(1, 0, 2) # (batch_size, seq_len, text_embed_dim) + tokens, residual = find_closest_tokens(yt, self.text_embed.weight) + + logging.info(f"Avg residual is {residual}") + + tokens = tokens.tolist() + # remove blanks. + tokens = [ [ s for s in sent if s != 0 ] for sent in tokens ] + + return tokens + + + +class FixedEmbedding(nn.Module): + def __init__(self, vocab_size: int, embed_dim: int, scale: float = 1.0): + super().__init__() + self.register_buffer('weight', scale * torch.randn(vocab_size, embed_dim), + persistent=True) + + def forward(self, y: Tensor): + y_shape = y.shape + ans = torch.index_select(self.weight, 0, y.flatten()) + return ans.reshape(*y_shape, -1) + + + +def find_closest_tokens(y: Tensor, weights: Tensor) -> Tuple[Tensor, Tensor]: + """ + Find closest token indexes to embedding vectors. + Args: + y: (..., embed_dim), the embeddings to match to weights. + weights: (num_tokens, embed_dim), the embedding vectors for each token. + + Returns: (tokens, avg_residual) + tokens: (...), a LongTensor containing the indexes of the closest tokens + avg_residual: a LongTensor containing the average difference (rms of elements) + between embeddings and weights. + """ + yy = (y ** 2).sum(dim=-1) # (...) + ww = (weights ** 2).sum(dim=-1) # (num_tokens,) + yw = torch.matmul(y, weights.t()) # (..., num_tokens) + # (y - w) ** 2 = y**2 + w**2 - 2 yw + + residuals = yy.unsqueeze(-1) + ww - 2 * yw + residuals, tokens = torch.min(residuals, dim=-1) + + embed_dim = weights.shape[1] + return tokens, (residuals.mean() / embed_dim).sqrt() + + + +def timestep_embedding(timesteps, dim, max_period=10000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + +def randomly_pad_to_lengths(y: Tensor, + y_lens: Tensor, + x_lens: Tensor, + max_x_len: int): + """ + Randomly insert blanks (symbol 0) into the symbol-sequences in y, with lengths y_lens, so that + they have lengths x_lens. All tensor are LongTensors (dtype torch.long) + Args: + y: (batch_size, max_y_len): the symbols; all positions less than the corresponding y_lens value + are expected to be nonzero. + y_lens: the lengths of the sequences in y, we expect that 1 <= y_lens <= max_y_len + x_lens: the lengths of the sequences we want to pad to, we expect that y_lens <= x_lens <= max_x_len. + """ + # checking that each y is not longer than corresponding x. + debug = True #(__name__ == '__main__') + length_diff = x_lens - y_lens + if debug: + assert length_diff.min() >= 0 + + (batch_size, max_y_len) = y.shape + + + y_mask = torch.arange(0, max_y_len + 1, device=y.device) >= y_lens.unsqueeze(-1) # (batch-size, max_y_len) + # y_mask is True for masked, i.e. non-valid, positions + + # cut_points are points at which we divide up the interval [0..y_len-x_len] which is + # the amount by which we want to pad. We want to get y_len + 1 "padding lengths" that + # sum to y_len-x_len. We get these by taking the numbers: [ 0, , 1 , 1... ], + # multiplying by (y_len-x_len), so we have: [ 0, , y_len-x_len, y_len-x_len.. ], + # and take the differences between each one and the next, so we get: + # [ , 0, 0, ... ] and the counts add up to y_len-x_len. + # + cut_points = torch.rand(batch_size, max_y_len + 2, device=y.device) + cut_points[:, 1:].masked_fill_(y_mask, 1.0) + cut_points[:, 0] = 0.0 + cut_points = cut_points * length_diff.unsqueeze(-1) + cut_points = cut_points.sort(dim=1)[0] + cut_points = cut_points.round().to(torch.long) + num_pad = cut_points[:, 1:] - cut_points[:, :-1] + + + + num_symbols = torch.empty(batch_size, 2 * max_y_len, device=y.device, dtype=torch.long) + num_symbols[:, 1::2] = (1 - y_mask[:, :-1].to(torch.long)) # the actual symbols have length 1. + num_symbols[:, 0:-1:2] = num_pad[:, :-1] # assign the number of padding symbols for each position. + # we don't need the last padding length, it doesn't determine any symbol position. + + symbol_positions = num_symbols.cumsum(dim=1) + symbol_positions = symbol_positions[:, 0::2] + + # the "+ 1" is because the symbol_positions will actually contain, in the padding + # positions, a number equal to the corresponding values in x_lens; and this may + # be out of range in the scatter_ unless we add one padding element. + padded_symbols = torch.zeros(batch_size, max_x_len + 1, device=y.device, dtype=torch.long) + padded_symbols.scatter_(dim=1, index=symbol_positions, src=y) + padded_symbols = padded_symbols[:, :-1] # remove the one padding position + x_mask = torch.arange(0, max_x_len, device=y_lens.device) < x_lens.unsqueeze(-1) + if debug: + assert torch.all(padded_symbols == padded_symbols * x_mask) + return padded_symbols + + +def _test_find_closest_tokens(): + vocab_size = 10 + embed_dim = 30 + text_embed = FixedEmbedding(vocab_size, embed_dim) + tokens = torch.randint(0, vocab_size, (3, 4), dtype=torch.long) + + embeddings = text_embed(tokens) + embeddings = embeddings + 0.05 * torch.randn_like(embeddings) + + tokens2, residual = find_closest_tokens(embeddings, text_embed.weight) + print("Residual = ", residual) # should be around 0.05. + assert torch.all(tokens2 == tokens) + + +def _test_randomly_distribute_labels(): + y = torch.tensor([ [ 1, 2, 3, 4 ], [ 5, 6, 7, 0 ], [ 8, 9, 0, 0 ] ]) + y_lens = torch.tensor([ 4, 3, 2 ] ) + x_lens = torch.tensor([ 8, 6, 5 ]) + max_x_len = 7 + y = randomly_pad_to_lengths(y, y_lens, x_lens, max_x_len) + print("y_padded = ", y) + + + + +if __name__ == '__main__': + _test_find_closest_tokens() + for _ in range(10): + _test_randomly_distribute_labels() diff --git a/egs/librispeech/ASR/zapformer_denoise/optim.py b/egs/librispeech/ASR/zapformer_denoise/optim.py new file mode 120000 index 0000000000..207eecfcda --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/pretrained.py b/egs/librispeech/ASR/zapformer_denoise/pretrained.py new file mode 120000 index 0000000000..70ad71ffc6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/pretrained.py @@ -0,0 +1 @@ +../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/scaling.py b/egs/librispeech/ASR/zapformer_denoise/scaling.py new file mode 120000 index 0000000000..58e4b0a0fe --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/scaling.py @@ -0,0 +1 @@ +../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/speech_recognition.py b/egs/librispeech/ASR/zapformer_denoise/speech_recognition.py new file mode 100755 index 0000000000..dd069cf3da --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/speech_recognition.py @@ -0,0 +1,229 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class K2SpeechRecognitionDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech recognition task using k2 library. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + if self.cut_transforms: + orig_cuts = cuts + + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts + num_copies = 3 + else: + num_copies = 1 + + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) diff --git a/egs/librispeech/ASR/zapformer_denoise/subsampling.py b/egs/librispeech/ASR/zapformer_denoise/subsampling.py new file mode 100644 index 0000000000..03e0319feb --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/subsampling.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Tuple, Optional + +import torch +from scaling import ( + ScaleLimiter, + ScaledLinear, + ExpNorm, + Dropout3, + FloatLike, + ScaledConv2d, + ScaleGrad, + ScheduledFloat, + SwashL, + SwashR, + Whiten, +) +from torch import Tensor, nn + + +class ConvNeXt(nn.Module): + """ + Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf + """ + + def __init__( + self, + channels: int, + hidden_ratio: int = 3, + kernel_size: Tuple[int, int] = (7, 7), + ): + super().__init__() + self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + hidden_channels = channels * hidden_ratio + + self.depthwise_conv = nn.Conv2d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=self.padding, + ) + + self.pointwise_conv1 = nn.Conv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1, + ) + + self.activation = SwashL() + + self.pointwise_conv2 = nn.Conv2d( + in_channels=hidden_channels, + out_channels=channels, + kernel_size=1, + ) + + + def forward( + self, x: Tensor, + ) -> Tensor: + """ + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + + The returned value has the same shape as x. + """ + bypass = x + x = self.depthwise_conv(x) + x = self.pointwise_conv1(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + x = bypass + x + + return x + + def streaming_forward( + self, + x: Tensor, + cached_left_pad: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) + + Returns: + - The returned value has the same shape as x. + - Updated cached_left_pad. + """ + padding = self.padding + + # The length without right padding for depth-wise conv + T = x.size(2) - padding[0] + + bypass = x[:, :, :T, :] + + # Pad left side + assert cached_left_pad.size(2) == padding[0], ( + cached_left_pad.size(2), + padding[0], + ) + x = torch.cat([cached_left_pad, x], dim=2) + # Update cached left padding + cached_left_pad = x[:, :, T : padding[0] + T, :] + + # depthwise_conv + x = torch.nn.functional.conv2d( + x, + weight=self.depthwise_conv.weight, + bias=self.depthwise_conv.bias, + padding=(0, padding[1]), + groups=self.depthwise_conv.groups, + ) + x = self.pointwise_conv1(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + x = bypass + x + return x, cached_left_pad + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: FloatLike = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-3)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite + """ + assert in_channels >= 7 + self.in_channels = in_channels + super().__init__() + + # The ScaleGrad module is there to prevent the gradients + # w.r.t. the weight or bias of the first Conv2d module in self.conv from + # exceeding the range of fp16 when using automatic mixed precision (amp) + # training. (The second one is necessary to stop its bias from getting + # a too-large gradient). + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ScaleGrad(0.2), + SwashR(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + SwashR(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + padding=0, + ), + SwashR(), + ) + + + # just one convnext layer + self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + + # (in_channels-3)//4 + self.out_width = (in_channels-3) // 4 + self.layer3_channels = layer3_channels + + # scale it up a bit, else the output is quite small. + self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, + initial_scale=4.0) + + # use a larger than normal grad_scale on this whitening module; there is + # only one such module, so there is not a concern about adding together + # many copies of this extra gradient term. + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), + prob=(0.025, 0.25), + grad_scale=0.02, + ) + + # max_log_eps=0.0 is to prevent both eps and the output of self.out from + # getting large, there is an unnecessary degree of freedom. + self.out_norm = ExpNorm(out_channels) + self.dropout = Dropout3(dropout, shared_dim=1) + + def pad(self, x: torch.Tensor) -> Tensor: + (N, T, idim) = x.shape + + + right_pad = (4 * ((T + 3) // 4)) - T + # first, pad to be a multiple of 4 frames. this is so we can later reconstruct at + # least the original number of frames. + + # next, we have to add 5 frames in order to get, finally (T + right_pad) // 4 frames. + left_pad = 3 + right_pad = 2 + right_pad + return torch.nn.functional.pad(x, (0, 0, left_pad, right_pad)) + + + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + + Returns: + - a tensor of shape (N, (T-3)//4, odim) + - output lengths, of shape (batch_size,) + """ + # On entry, x is (batch_size, time, ideim) + x = self.pad(x) + # define x shape now as (N, T, idim) with T being the padded shape. + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + x = self.convnext(x) + + # Now x is of shape (N, odim, (T-5)//4, (idim-3)//4) + b, c, t, f = x.size() + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, (T-5)//4, out_width * layer3_channels)) + + x = self.out(x) + # Now x is of shape (N, (T-5)//4, odim) + x = self.out_whiten(x) + x = self.out_norm(x) + x = self.dropout(x) + + # the "+ 3" reflects the rounding-up-to-a-multiple-of-4 that we do at + # the start of self.pad(). We would, without self.pad() need to have a + # "-5" here and the adding 5 frames in self.pad() cancels that out. + if torch.jit.is_scripting() or torch.jit.is_tracing(): + x_lens = (x_lens + 3) // 4 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = (x_lens + 3) // 4 + assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) + + return x, x_lens diff --git a/egs/librispeech/ASR/zapformer_denoise/test_scaling.py b/egs/librispeech/ASR/zapformer_denoise/test_scaling.py new file mode 120000 index 0000000000..b776da79a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/test_scaling.py @@ -0,0 +1 @@ +../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/train.py b/egs/librispeech/ASR/zapformer_denoise/train.py new file mode 100755 index 0000000000..dba253163d --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/train.py @@ -0,0 +1,1378 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import DenoisingAsrModel +from optim import Sched3, TransformedAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zapformer import Zapformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def get_adjusted_lr_batches(params: AttributeDict) -> float: + # returns an adjusted form of the "lr_batches" parameter used to set the learning + # rate in the Sched3 scheduler. + # We want the final LR to be based on the geometric mean of "how much data we + # have seen" and "how many batches we have seen". + # an easier way to look at it is this: the formula for learning rate depends + # on (cur_batch / lr_batches). if we write this as: + # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches + # then the numerator is a geometric mean of "how many batches we have seen" + # and "how much data we have seen". We can achieve this by setting + # lr_batches = params.lr_batches * (duration_ratio ** -0.5). + duration_ratio = (params.max_duration * params.world_size) / params.ref_duration + lr_batches = params.lr_batches * (duration_ratio ** -0.5) + logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") + return lr_batches + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def lookup(params: AttributeDict, name: str): + """ + Interprets numerical arguments in `params` by taking into account base-dim; + also parses comma-separated lists of integers, turning them into tuples. + If a particular attribute ending in "dim" is not present we look up + the same name but ending in "factor", and multiply the elements by base_dim. + """ + try: + attr = getattr(params, name) + try: + attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints + if len(attr) == 1: + attr = attr[0] + except: + pass # leave attr as it is, e.g. a string. + return attr + except AttributeError as e: + if name[-3:] != "dim": + raise e + try: + attr = getattr(params, name[:-3] + "multiple") + if isinstance(attr, str): + attr = tuple(map(int, attr.split(","))) # tuple of ints + base_dim = params.base_dim + attr = tuple([i * base_dim for i in attr]) + if len(attr) == 1: + attr = attr[0] + else: # assume int. + assert isinstance(attr, (int, float)), (name, attr) + attr = attr * params.base_dim + return attr + except AttributeError as e: + raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") + + + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="8,8,8", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,1,1", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--base-dim", + type=int, + default=64, + help="Dimension that, via multiples, defines the dimensions of the model." + ) + + parser.add_argument( + "--embed-multiple", + type=int, + default=6, + help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", + ) + + parser.add_argument( + "--text-embed-dim", + type=int, + default=8, + help="Dim of text embeddings.", + ) + + parser.add_argument( + "--speech-loss-scale", + type=float, + default=1.0, + help="Loss scale on the speech part of the loss", + ) + + parser.add_argument( + "--time-embed-multiple", + type=int, + default=4, + help="Multiply by base-dim to determine dimension of time embedding." + ) + + + parser.add_argument( + "--feedforward-multiple", + type=str, + default="3", + help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,8,4", + help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-multiple", + type=str, + default="6,6,6", + help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--debug-interval", + type=int, + default=10, + help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for + finding parts of the network that are diverging or not well trained. + """ + ) + + parser.add_argument( + "--dump-debug-interval", + type=int, + default=0, + help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they + are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. + Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file + opened, seeked-in and closed for each scalar that is written). + """ + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer_denoise/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=17500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on pruned loss (for transducer) and the reconstruction and prediction + losses. Expressed in terms of the "adjusted batch count", i.e. the + normalized batch count after adjusting for changes in batch size. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_speech_embed(params: AttributeDict) -> nn.Module: + # speech_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + speech_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=lookup(params, "embed_dim"), + dropout=0.0, + ) + return speech_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zapformer( + input_dim=lookup(params, "embed_dim"), + time_embed_dim=lookup(params, "time_embed_dim"), + downsampling_factor=lookup(params, "downsampling_factor"), + num_encoder_layers=lookup(params, "num_encoder_layers"), + encoder_dim=lookup(params, "encoder_dim"), + query_head_dim=lookup(params, "query_head_dim"), + pos_head_dim=lookup(params, "pos_head_dim"), + value_head_dim=lookup(params, "value_head_dim"), + pos_dim=params.pos_dim, + num_heads=lookup(params, "num_heads"), + feedforward_multiple=lookup(params, "feedforward_multiple"), + cnn_module_kernel=lookup(params, "cnn_module_kernel"), + dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "decoder_dim"), + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + + + +def get_model(params: AttributeDict) -> nn.Module: + + #speech_embed = get_speech_embed(params) + encoder = get_encoder_model(params) + + + model = DenoisingAsrModel( + #speech_embed=speech_embed, + encoder=encoder, + encoder_dim=lookup(params, "embed_dim"), # see embed-multiple + text_embed_dim=lookup(params, "text_embed_dim"), + vocab_size=params.vocab_size, + time_embed_dim=lookup(params, "time_embed_dim") # see time-embed-multiple + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + x = batch["inputs"] + # at entry, feature is (N, T, C) + assert x.ndim == 3 + x = x.to(device) + + supervisions = batch["supervisions"] + x_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) # list of lists. + y_lens = [ len(sent) for sent in y ] + max_y_len = max(y_lens) + y = [ sent + [ 0 ] * (max_y_len - len(sent)) for sent in y ] + y = torch.tensor(y).to(device) + y_lens = torch.tensor(y_lens).to(device) + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + nframes = (x_lens // params.subsampling_factor).sum().item() + info["frames"] = nframes + + with torch.set_grad_enabled(is_training): + speech_loss, text_loss = model(x, x_lens, y, y_lens) + # (speech_loss - 2 * nframes).relu() is to prevent it from completely ignoring the speech loss. + loss = params.speech_loss_scale * speech_loss + (speech_loss - (2.0 * nframes)).relu() + text_loss + + + assert loss.requires_grad == is_training + + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["text_loss"] = text_loss.detach().cpu().item() + info["speech_loss"] = speech_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + if params.debug_interval > 0: + optimizer.write_debug_info(summary_writer=tb_writer) + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or + batch_idx % 100 == 0 and cur_grad_scale < 8.0 or + batch_idx % 400 == 0 and cur_grad_scale < 32.0): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + optimizer.write_debug_info(summary_writer=tb_writer) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + params.vocab_size = sp.get_piece_size() + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = TransformedAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + debug_interval=params.debug_interval, + ) + + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and False: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + d = diagnostic.print_diagnostics() + filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" + torch.save(d, filename) + logging.info(f"Saved detailed diagnostics to {filename}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + spec_augment: Optional[nn.Module] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer_denoise/zapformer.py b/egs/librispeech/ASR/zapformer_denoise/zapformer.py new file mode 100644 index 0000000000..e9839ff451 --- /dev/null +++ b/egs/librispeech/ASR/zapformer_denoise/zapformer.py @@ -0,0 +1,1344 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + OrthogonalLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + ScaleLimiter, + ActivationDropoutAndLinear, + ExpNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +class Zapformer(nn.Module): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + time_embed_dim: an integer giving the dimension of the time embeddings provided + to the network. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + def __init__( + self, + input_dim: int, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + time_embed_dim: int = 256, + num_encoder_layers: Union[int, Tuple[int]] = 4, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_multiple: Union[int, Tuple[int]] = 4, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super().__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_multiple = _to_tuple(feedforward_multiple) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + # each one will be ZapformerEncoder or OrthogonalDownsample or OrthogonalUpsample + encoders = [] + + num_encoders = len(downsampling_factor) + cur_downsample = 1 + + # caution: some changes we made for this break the streaming, later we'll try to fix this. + encoders_downsampling_factors = [ ] + + # make it so large the limit is never reached. + max_proj_dim = max(downsampling_factor) * max(encoder_dim) + + def set_downsample_factor(cur_downsample, ds): + while cur_downsample < ds: + # need to downsample + encoders.append(OrthogonalDownsample(channels=input_dim * cur_downsample, + proj_dim=min(2 * input_dim * cur_downsample, max_proj_dim))) + cur_downsample *= 2 + while cur_downsample > ds: + encoders.append(OrthogonalUpsample(channels=input_dim * cur_downsample, + proj_dim=min(input_dim * cur_downsample, max_proj_dim))) + cur_downsample //= 2 + return cur_downsample + + for i in range(num_encoders): + cur_downsample = set_downsample_factor(cur_downsample, downsampling_factor[i]) + + encoder_layer = ZapformerEncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_multiple=feedforward_multiple[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = ZapformerEncoder( + encoder_layer, + num_encoder_layers[i], + dim=cur_downsample*input_dim, + pos_dim=pos_dim, + time_embed_dim=time_embed_dim, + ) + encoder.encoder_index = i + encoders.append(encoder) + + cur_downsample = set_downsample_factor(cur_downsample, 1) + + self.encoders = nn.ModuleList(encoders) + + + def forward( + self, + x: Tensor, + time_embed: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, encoder_dim). + time_embed: + The timestep-embedding tensor. Its shape is (batch_size, time_embed_dim) + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return embeddings with the same shape as x: (seq_len, batch_size, encoder_dim) + """ + orig_seq_len = x.shape[0] + + def truncate(x, downsampling_factor): + max_len = (orig_seq_len + downsampling_factor - 1) // downsampling_factor + return x[:max_len] if x.shape[0] > max_len else x + + + for module in self.encoders: + if isinstance(module, ZapformerEncoder): + i = module.encoder_index # was set in this class's __init__ function. + ds = self.downsampling_factor[i] + x = truncate(x, ds) + x = module( + x, + time_embed, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + ) + else: + x = module(x) + + x = x[:orig_seq_len] + return x + + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + +class ZapformerEncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_multiple: determines the hidden dimension of the feedforward module + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). + + Examples:: + >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_multiple: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.75)), + ) -> None: + super(ZapformerEncoderLayer, self).__init__() + self.embed_dim = embed_dim + self.name = None # will be set from training loop + + self.randomize_scale = copy.deepcopy(randomize_scale) + # self.bypass implements layer skipping as well as learnable scale on a residual term; see its default values. + self.residual = ResidualModule( + embed_dim, + ) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] + + feedforward_dim = embed_dim * feedforward_multiple + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) + + self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + for _ in range(2) ] + + self.scale_limiter = ScaleLimiter(max_var=2.0) + + self.norm = ExpNorm(embed_dim) + + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + src = src + self.self_attn1(src, attn_weights) + + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward2(src) + + src = src + self.self_attn2(src, attn_weights) + + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward3(src) + + src = self.residual(src_orig, src) + + src = self.scale_limiter(src) + + src = self.norm(src) + + return src + + +class ZapformerEncoder(nn.Module): + r"""ZapformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZapformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + dim: the dimension of the input and output (layer dim may be less than this). + pos_dim: the dimension for the relative positional encoding +dropout: + + Examples:: + >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = ZapformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + + + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dim: int, + pos_dim: int, + time_embed_dim: int, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.0, length_factor=1.0 + ) + self.name = None + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.residual = ResidualModule(encoder_layer.embed_dim) + + self.time_embed = ScaledLinear(time_embed_dim, encoder_layer.embed_dim, initial_scale=0.1) + + #bypass_dim = dim - encoder_layer.embed_dim + self.copy_bypass = Identity() + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(3.0), + prob=(1, 1), + grad_scale=0.025, + ) + + + + def forward( + self, + src: Tensor, + time_embed: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), + but embed_dim is allowed to exceed the modules' embed_dim; we will bypass + any extra dimensions. + time_embed: the time embedding, shape: (batch_size, seq_len) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + + num_channels = src.shape[-1] + layer_dim = self.layers[0].embed_dim + if num_channels > layer_dim: + src, bypass = src[..., :layer_dim], src[..., layer_dim:] + + + src_orig = src + src = src + self.time_embed(time_embed) + for i, mod in enumerate(self.layers): + src = mod( + src, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + # randomize_factor can be viewed as a simple version of an + # importance-sampling factor. + + src = self.residual(src_orig, src) + src = self.whiten(src) + + if num_channels > layer_dim: + bypass = self.copy_bypass(bypass) + src = torch.cat((src, bypass), dim=-1) + + return src + + +class ResidualModule(nn.Module): + """ + An nn.Module that implements a learnable residual scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + function_scale_min: FloatLike = 0.1, + ): + super().__init__() + self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.function_scale_min = copy.deepcopy(function_scale_min) + + + def _get_scales(self): + function_scale = self.function_scale + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + function_scale = limit_param_value( + function_scale, min=float(self.function_scale_min), max=1.0, + ) + residual_scale = 1.0 - function_scale + return residual_scale, function_scale + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + residual_scale, function_scale = self._get_scales() + return residual_scale * src_orig + function_scale * src + + + +class OrthogonalDownsample(torch.nn.Module): + """ + Does downsampling with an orthogonal matrix, by a factor of two. Projection is initialized + in a special way and enforced to be orthogonal. + + Args: + channels: the number of input channels; the num output channels will be twice this + proj_dim: the number of channels, after combining 2 frames by interpolating their channels + as [ a b a b, .. ] that will actually be projected; the rest are just copied. + proj_dim=2 * channels would mean all channels are projected in a learned way + causal: True for causal systems, only affects error messages as requires even + input num frames. + """ + def __init__( + self, channels: int, proj_dim: int, causal: bool = False, + ): + super().__init__() + assert proj_dim <= channels * 2 + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + self.causal = causal + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + + if seq_len % 2 == 1: + if torch.jit.is_tracing(): + assert ( + not self.causal + ), f"pad should be zero for exporting streaming models. Given {pad}" + src = torch.cat((src, src[-1:]), dim=0) + seq_len += 1 + + # the following will place each 2 frames of a particular channel right after + # each other as if they were two different channels. + src = torch.stack((src[0::2], src[1::2]), dim=-1) + src = src.reshape(seq_len // 2, batch_size, in_channels * 2) + proj_channels = self.proj.weight.shape[0] + if proj_channels < in_channels * 2: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + return src + +class OrthogonalUpsample(torch.nn.Module): + """ + A very simple form of upsampling with an orthogonal matrix. + + proj_dim: the number of channels that will actually be projected; the rest are just copied. + proj_dim=channels would mean all channels are projected in a learned way + + """ + def __init__(self, channels: int, proj_dim: int): + super().__init__() + assert proj_dim <= channels + # gradually make smaller and then turn off the non-orthognality penalty. + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, + penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*2), batch_size, num_channels // 2) + """ + proj_channels = self.proj.weight.shape[0] + (seq_len, batch_size, in_channels) = src.shape + + if proj_channels < in_channels: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + + src = torch.stack((src[..., 0::2], src[..., 1::2]), + dim=1) # (seq_len, 2, batch_size, in_channels // 2) + src = src.reshape(seq_len * 2, batch_size, in_channels // 2) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the Fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0, embed_dim + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0, length_factor + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) + self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, + bias=True, initial_scale=0.125 * query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnostics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim, ( + p.shape[-1], + num_heads, + pos_head_dim, + ) + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(k) # does nothing in the forward pass. [this may not really be needed due to the orthogonality constraint.] + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + if True: + # position scores. + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < float(self.attn_score_penalty_prob): + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=float(self.attn_score_limit), penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, + bias=True, out_groups=num_heads) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + f = max(1.0, embed_dim / (num_heads * value_head_dim)) + # the whitening metric cannot be less than f because of the rank imposed + # by the bottleneck. the final whitening limit will be (2.0*3.0) times f, + # i.e. 6 times greater than the mathematical smallest value it can have. + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(f * 2.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zapformer model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + # try to get in the useful range of the activation function, i.e. not too small. + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed + # to the TransformedAdam optimizer. + self.in_proj.weight_min_rms = 0.02 + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwashL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.5, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zapformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwashR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + seq_len = 20 + # Just make sure the forward pass runs. + + input_dim = 50 + time_embed_dim = 64 + + c = Zapformer( + input_dim=input_dim, + encoder_dim=(64, 96), + time_embed_dim=time_embed_dim, + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + + batch_size = 6 # make it even, as PredictLoss requires even batch size. + seq_len = 21 + # Just make sure the forward pass runs. + time_embed = torch.randn(batch_size, time_embed_dim) + + f = c( + torch.randn(seq_len, batch_size, input_dim), + time_embed, + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f.sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, input_dim), + time_embed, + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3808f2230e..58095ef9af 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1480,12 +1480,11 @@ def __init__(self, dim: int = -1, eps: float = 1.0e-05): def forward(self, x: Tensor): dim = self.dim eps = self.eps - norm = (x ** 2).sum(dim=dim, keepdim=True).clamp(min=eps) ** -0.5 - x = x.clamp(min=eps) * norm - # x**2 is the probability, we return the log of that which is 2 * log(x). The probs x**2 cannot - # sum up to more than 1, because of the normalization above. (The sum may be less than 1, if some - # x values are negative.) This ignores clamping to eps though. - return 2 * x.log() + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + x_sum = (x ** 2).sum(dim=dim, keepdim=True).clamp(min=eps) + x = (x ** 2).clamp(min=eps*eps) / x_sum + return x.log() From 731f2707a6171f4a4e2fda471d63248c752db843 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:15:52 +0800 Subject: [PATCH 0410/1191] Change how eps is treated and make it much larger. --- egs/librispeech/ASR/zapformer/model.py | 6 +++--- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 34578a3fb2..5cc1aaf022 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -100,11 +100,11 @@ def __init__( self.joiner = joiner self.simple_am_proj = nn.Sequential( - ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + ScaledLinear(encoder_dim, vocab_size), SquareLogSoftmax(dim=-1), ) self.simple_lm_proj = nn.Sequential( - ScaledLinear(decoder_dim, vocab_size, initial_scale=0.1), + ScaledLinear(decoder_dim, vocab_size), SquareLogSoftmax(dim=-1), ) @@ -117,7 +117,7 @@ def __init__( # Modules for CTC head self.ctc_output = nn.Sequential( nn.Dropout(p=0.1), - ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + ScaledLinear(encoder_dim, vocab_size), SquareLogSoftmax(dim=-1), ) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 58095ef9af..7be539b5c2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1471,7 +1471,7 @@ def forward(self, x: Tensor) -> Tensor: class SquareLogSoftmax(nn.Module): - def __init__(self, dim: int = -1, eps: float = 1.0e-05): + def __init__(self, dim: int = -1, eps: float = 1.0e-03): super().__init__() self.dim = dim self.eps = eps @@ -1482,8 +1482,8 @@ def forward(self, x: Tensor): eps = self.eps with torch.amp.autocast('cuda', enabled=False): x = x.to(torch.float) - x_sum = (x ** 2).sum(dim=dim, keepdim=True).clamp(min=eps) - x = (x ** 2).clamp(min=eps*eps) / x_sum + dim = x.shape[-1] + x = ((x ** 2) + eps/dim) / (x_sum + eps) return x.log() From c747572500b0837f5e710d021c34dc05598f311b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:23:16 +0800 Subject: [PATCH 0411/1191] bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7be539b5c2..aa1a963007 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1482,8 +1482,9 @@ def forward(self, x: Tensor): eps = self.eps with torch.amp.autocast('cuda', enabled=False): x = x.to(torch.float) - dim = x.shape[-1] - x = ((x ** 2) + eps/dim) / (x_sum + eps) + channels = x.shape[dim] + x_sq = x ** 2 + x = (x_sq + eps/channels) / (x_sq.sum(dim=dim, keepdim=True) + eps) return x.log() From 11b165f44d2d12cda4e3f5ed9af5f55c706727f3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:24:22 +0800 Subject: [PATCH 0412/1191] Change for denoise_branch14, will merge there. Reverts speech_scale to 0.1 and sets t in 0.5..1. --- egs/librispeech/ASR/zapformer_denoise/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer_denoise/model.py b/egs/librispeech/ASR/zapformer_denoise/model.py index 968575ecef..4452d5a61f 100755 --- a/egs/librispeech/ASR/zapformer_denoise/model.py +++ b/egs/librispeech/ASR/zapformer_denoise/model.py @@ -44,7 +44,7 @@ def __init__( """ super().__init__() - self.speech_scale = 0.5 + self.speech_scale = 0.1 self.encoder = encoder self.encoder_dim = encoder_dim @@ -101,7 +101,7 @@ def forward( batch_size = x.shape[0] assert x.shape[0] == x_lens.shape[0] == y.shape[0], (x.shape, x_lens.shape, y.shape) - s = torch.rand(batch_size, device=x.device) # time-value for speech. + s = torch.empty(batch_size, device=x.device).uniform_(0.5, 1.0) # time-value for speech. only have >= 0.5 t = torch.rand(batch_size, device=x.device) # time-value for text. From 12f6065c5ec01472f346cc1b0f4577b63c95bebd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:35:07 +0800 Subject: [PATCH 0413/1191] Reverting model.py to version from branch deterministic_invertible919conv, removing SquareLogSoftmax. --- egs/librispeech/ASR/zapformer/model.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 5cc1aaf022..7573778c81 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from torch import Tensor from encoder_interface import EncoderInterface -from scaling import ScaledLinear, convert_num_channels, SquareLogSoftmax +from scaling import ScaledLinear, convert_num_channels from icefall.utils import add_sos, make_pad_mask, time_warp @@ -99,13 +99,11 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = nn.Sequential( - ScaledLinear(encoder_dim, vocab_size), - SquareLogSoftmax(dim=-1), + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.1, ) - self.simple_lm_proj = nn.Sequential( - ScaledLinear(decoder_dim, vocab_size), - SquareLogSoftmax(dim=-1), + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.1, ) else: @@ -117,8 +115,8 @@ def __init__( # Modules for CTC head self.ctc_output = nn.Sequential( nn.Dropout(p=0.1), - ScaledLinear(encoder_dim, vocab_size), - SquareLogSoftmax(dim=-1), + ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + nn.LogSoftmax(dim=-1), ) self.use_attention_decoder = use_attention_decoder From 2f727d9eefef0cfb5d79a0a0497c1b3605eccfc0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:42:51 +0800 Subject: [PATCH 0414/1191] Add fangjun's code for mel warping. --- icefall/exp_augment.py | 186 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 695ecd604a..e13c31103e 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -236,6 +236,191 @@ def load_state_dict(self, state_dict: Dict[str, Any]): +def hz_to_mel(hz: torch.Tensor): + return 1127.0 * torch.log(1 + hz / 700) + + +def mel_to_hz(mel: torch.Tensor): + return 700 * ((mel / 1127).exp() - 1) + + +def compute_mel_normalized_indexes( + low_freq_hz: float, + high_freq_hz: float, + sample_rate_hz: float, + num_mel_bins: float, + shift: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Return a tuple containing normalized indexes. + + - The first tensor is for expansion, i.e., map the second-to-last + bin to the last bin + + - The second tensor is for contraction, i.e., map the last bin to + the second-to-last bin + """ + nyquist = sample_rate_hz * 0.5 + if high_freq_hz <= 0: + high_freq_hz = nyquist + high_freq_hz + + assert 0 <= low_freq_hz < high_freq_hz <= nyquist, ( + low_freq_hz, + high_freq_hz, + nyquist, + sample_rate_hz, + ) + assert num_mel_bins > 1, num_mel_bins + + low_high_mel = hz_to_mel( + torch.tensor([low_freq_hz, high_freq_hz], dtype=torch.float32) + ) + + # divided by num_mel_bins + 1 to match the one used in Kaldi + mel_freq_delta = (low_high_mel[1] - low_high_mel[0]) / (num_mel_bins + 1) + + # the formulate to compute the mel tensor below is from Kaldi + mel = low_high_mel[0] + mel_freq_delta * torch.arange(num_mel_bins) + + hz = mel_to_hz(mel) + + expansion_scale = hz[-1] / hz[-1 - shift] # e.g. 1.0338 + contraction_scale = 1 / expansion_scale # e.g., 0.9673 + + mel_expanded = hz_to_mel(hz * expansion_scale) + mel_contracted = hz_to_mel(hz * contraction_scale) + + mel_expanded_indexes = (mel_expanded - low_high_mel[0]) / mel_freq_delta + mel_contracted_indexes = (mel_contracted - low_high_mel[0]) / mel_freq_delta + + mel_expanded_normalized_indexes = mel_expanded_indexes * 2 / (num_mel_bins - 1) - 1 + + mel_contracted_normalized_indexes = ( + mel_contracted_indexes * 2 / (num_mel_bins - 1) - 1 + ) + + return mel_expanded_normalized_indexes, mel_contracted_normalized_indexes + + +class MelWarp(torch.nn.Module): + def __init__( + self, + low_freq_hz: float, + high_freq_hz: float, + sample_rate_hz: float, + num_mel_bins: int, + p: float, + max_shift: int = 1, + ): + super().__init__() + + assert 0 <= p <= 1, p + assert 1 <= max_shift < num_mel_bins - 1 + + indexes = [] + for i in range(1, max_shift + 1): + expansion_indexes, contraction_indexes = compute_mel_normalized_indexes( + low_freq_hz=low_freq_hz, + high_freq_hz=high_freq_hz, + sample_rate_hz=sample_rate_hz, + num_mel_bins=num_mel_bins, + shift=i, + ) + indexes.append(expansion_indexes) + indexes.append(contraction_indexes) + + self.indexes = torch.stack(indexes, dim=0) + + self.num_mel_bins = num_mel_bins + self.p = p + + def forward(self, features: torch.Tensor) -> torch.Tensor: + B, T, C = features.shape + assert C == self.num_mel_bins, (C, self.num_mel_bins) + + device = features.device + + features = features.permute(0, 2, 1) + + # grid sample requires (N,C,H,W) input + # we treat the feature axis as h, the time axis as w + # and use 1 for the channel in NCHW + + h = torch.linspace(-1, 1, C)[None, :, None].expand(B, C, T).to(device) + + # select a different index for each audio in the batch + # where each index corresponds to a shift + index = torch.randint( + low=0, high=self.indexes.shape[0], size=(B,), dtype=torch.int64 + ) + + warped_indexes = self.indexes[index][:, :, None].expand(B, C, T).to(device) + + h_positions = torch.where( + torch.rand(B, 1, 1).expand_as(features) < self.p, + warped_indexes, + h, + ) + + w = torch.linspace(-1, 1, T)[None, None, :].expand(B, C, T).to(device) + + grid = torch.stack([w, h], axis=-1) + + features = torch.nn.functional.grid_sample( + features.unsqueeze(1), + grid, + mode="bicubic", + padding_mode="border", + align_corners=True, + ) + return features.squeeze(1).permute(0, 2, 1) + + +def _test_grid_sample(): + f = torch.rand(50, 20, 80) # (batch, time, features) + B, T, C = f.shape + + h = torch.linspace(-1, 1, C)[None, :, None].expand(B, C, T) + w = torch.linspace(-1, 1, T)[None, None, :].expand(B, C, T) + # w is x + # h is y + grid = torch.stack([w, h], axis=-1) + f2 = [] + for aligned in [True, False]: + f2.append( + torch.nn.functional.grid_sample( + f.permute(0, 2, 1).unsqueeze(1), + grid, + mode="bicubic", + padding_mode="border", + align_corners=aligned, + ) + .squeeze(1) + .permute(0, 2, 1) + ) + print("align_corners=true", (f - f2[0]).abs().max()) # aligned true + print("align_corners=false", (f - f2[1]).abs().max()) # aligned false + + +def _test_mel_warp(): + # The parameters used in testing are default values in lhotse + mel_warp = MelWarp( + low_freq_hz=20, + high_freq_hz=-400, + sample_rate_hz=16000, + num_mel_bins=80, + p=1, + max_shift=4, + ) + + f0 = torch.rand(2, 20, 80) * 10 + f1 = mel_warp(f0) + + assert f0.shape == f1.shape + print((f0 - f1).abs().max()) + + + def _test_exp_augment(): for n in [ 0, 1 ]: #device = 'cuda' @@ -285,3 +470,4 @@ def _test_exp_augment(): if __name__ == '__main__': _test_exp_augment() + _test_mel_warp() From 9c189e74bd19c4a254f417b26cfb34f5a0f5fa46 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:44:23 +0800 Subject: [PATCH 0415/1191] Use fangjun's code for mel warping in model.py --- egs/librispeech/ASR/zapformer/model.py | 50 +++++++++++++++++++------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 7573778c81..22b56af281 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -87,6 +87,15 @@ def __init__( self.encoder_embed = encoder_embed self.encoder = encoder + self.mel_warp = MelWarp( + low_freq_hz=20, + high_freq_hz=-400, + sample_rate_hz=16000, + num_mel_bins=80, + p=0.9, + max_shift=4) + + self.use_transducer = use_transducer if use_transducer: # Modules for Transducer head @@ -432,22 +441,37 @@ def forward( B = batch_size // num_copies x = x.reshape(num_copies, B, seq_len, num_channels) - # Apply time warping. First append the copies on the channel - # dimension so all copies get the exact same time-warping. - x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) - - assert supervision_segments is not None - x = time_warp( - x, - time_warp_factor=time_warp_factor, - supervision_segments=supervision_segments[:B], - ) - x = x.reshape(B, seq_len, num_copies, num_channels) - x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + time_warp = True + mel_warp = True + if time_warp: + # Apply time warping. First append the copies on the channel + # dimension so all copies get the exact same time-warping. + x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) + + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments[:B], + ) + x = x.reshape(B, seq_len, num_copies, num_channels) + x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + + if mel_warp: + # Apply mel warping. First append the copies on the sequence + # dimension so all copies of the data get the exact same + # mel-warping. (this is done mostly for purposes of the reconstruction + # loss). + x = x.permute(1, 0, 2, 3) # (B, num_copies, seq_len, num_channels) + x = x.reshape(B, num_copies * seq_len, num_channels) + + x = self.mel_warp(x) + x = x.reshape(B, num_copies, seq_len, num_channels) + x = x.permute(1, 0, 2, 3) # (num_copies, B, seq_len, num_channels) # x_no_specaug is several repeats of the 1st copy of the data, which # is the one not augmented with Musan. But it does have time - # warping. + # warping and mel warping. x_no_specaug = x[0:1].repeat(num_copies - 1, 1, 1, 1).reshape( B * (num_copies - 1), seq_len, num_channels) From e9f6a7ccc281ff8e334196f0dbb4d560fbac669e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:55:30 +0800 Subject: [PATCH 0416/1191] Bug fix from fangjun. --- icefall/exp_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index e13c31103e..901a91efb1 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -364,7 +364,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: w = torch.linspace(-1, 1, T)[None, None, :].expand(B, C, T).to(device) - grid = torch.stack([w, h], axis=-1) + grid = torch.stack([w, h_positions], axis=-1) features = torch.nn.functional.grid_sample( features.unsqueeze(1), From 5364cc0a3459adca0805a0db9619e87ee5332442 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 16:56:32 +0800 Subject: [PATCH 0417/1191] Fix missing import --- egs/librispeech/ASR/zapformer/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 22b56af281..0f5bbbf61c 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -26,6 +26,7 @@ from scaling import ScaledLinear, convert_num_channels from icefall.utils import add_sos, make_pad_mask, time_warp +from icefall.exp_augment import MelWarp class AsrModel(nn.Module): From 5c3fc5336637381a43f25f07f83985028bf0ab6a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 17:12:12 +0800 Subject: [PATCH 0418/1191] Fix variable naming bug --- egs/librispeech/ASR/zapformer/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 0f5bbbf61c..f0659f66e9 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -442,9 +442,9 @@ def forward( B = batch_size // num_copies x = x.reshape(num_copies, B, seq_len, num_channels) - time_warp = True - mel_warp = True - if time_warp: + do_time_warp = True + do_mel_warp = True + if do_time_warp: # Apply time warping. First append the copies on the channel # dimension so all copies get the exact same time-warping. x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) @@ -458,7 +458,7 @@ def forward( x = x.reshape(B, seq_len, num_copies, num_channels) x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) - if mel_warp: + if do_mel_warp: # Apply mel warping. First append the copies on the sequence # dimension so all copies of the data get the exact same # mel-warping. (this is done mostly for purposes of the reconstruction From 3aa24306119c00e067fbbd94a53be1a58c1c4d86 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 17:16:02 +0800 Subject: [PATCH 0419/1191] Fix bug RE devices --- icefall/exp_augment.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/icefall/exp_augment.py b/icefall/exp_augment.py index 901a91efb1..1bfb97e576 100644 --- a/icefall/exp_augment.py +++ b/icefall/exp_augment.py @@ -329,7 +329,7 @@ def __init__( indexes.append(expansion_indexes) indexes.append(contraction_indexes) - self.indexes = torch.stack(indexes, dim=0) + self.register_buffer('indexes', torch.stack(indexes, dim=0)) self.num_mel_bins = num_mel_bins self.p = p @@ -346,23 +346,23 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: # we treat the feature axis as h, the time axis as w # and use 1 for the channel in NCHW - h = torch.linspace(-1, 1, C)[None, :, None].expand(B, C, T).to(device) + h = torch.linspace(-1, 1, C, device=device)[None, :, None].expand(B, C, T).to(device) # select a different index for each audio in the batch # where each index corresponds to a shift index = torch.randint( - low=0, high=self.indexes.shape[0], size=(B,), dtype=torch.int64 + low=0, high=self.indexes.shape[0], size=(B,), dtype=torch.int64, device=device, ) warped_indexes = self.indexes[index][:, :, None].expand(B, C, T).to(device) h_positions = torch.where( - torch.rand(B, 1, 1).expand_as(features) < self.p, + torch.rand(B, 1, 1, device=device).expand_as(features) < self.p, warped_indexes, h, ) - w = torch.linspace(-1, 1, T)[None, None, :].expand(B, C, T).to(device) + w = torch.linspace(-1, 1, T, device=device)[None, None, :].expand(B, C, T) grid = torch.stack([w, h_positions], axis=-1) From 4b44127d8623fc8c72e758bf9b299355a2fafcf5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 23 Jul 2025 18:26:54 +0800 Subject: [PATCH 0420/1191] Remove something I had mistakenly left in. --- egs/librispeech/ASR/zipformer/joiner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index 76ce229c13..0406efe834 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from scaling import ScaledLinear, SquareLogSoftmax +from scaling import ScaledLinear class Joiner(nn.Module): @@ -32,7 +32,6 @@ def __init__( self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) self.output_linear = nn.Linear(joiner_dim, vocab_size) - self.output_log_softmax = SquareLogSoftmax(dim=-1) def forward( self, @@ -63,6 +62,6 @@ def forward( else: logit = encoder_out + decoder_out - logit = self.output_log_softmax(self.output_linear(torch.tanh(logit))) + logit = self.output_linear(torch.tanh(logit)) return logit From d5405fdc97ab7ff5b1a054c105e0cbb012ceb78a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Jul 2025 13:34:49 +0800 Subject: [PATCH 0421/1191] Reduce max_shift of MelWarp from 4 to 1 and p from .9 to .6. --- egs/librispeech/ASR/zapformer/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index f0659f66e9..af1ec9515c 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -93,8 +93,8 @@ def __init__( high_freq_hz=-400, sample_rate_hz=16000, num_mel_bins=80, - p=0.9, - max_shift=4) + p=0.666, + max_shift=1) self.use_transducer = use_transducer From 3baaa144f65a125b299ad581212eec9ead5fff15 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Jul 2025 18:50:04 +0800 Subject: [PATCH 0422/1191] Turn off amp in time warping and frequency warping. --- egs/librispeech/ASR/zapformer/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index af1ec9515c..75d3c1a2b8 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -450,11 +450,12 @@ def forward( x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) assert supervision_segments is not None - x = time_warp( - x, - time_warp_factor=time_warp_factor, - supervision_segments=supervision_segments[:B], - ) + with torch.amp.autocast('cuda', enabled=False): + x = time_warp( + x.to(torch.float), + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments[:B], + ) x = x.reshape(B, seq_len, num_copies, num_channels) x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) @@ -466,7 +467,8 @@ def forward( x = x.permute(1, 0, 2, 3) # (B, num_copies, seq_len, num_channels) x = x.reshape(B, num_copies * seq_len, num_channels) - x = self.mel_warp(x) + with torch.amp.autocast('cuda', enabled=False): + x = self.mel_warp(x.to(torch.float)) x = x.reshape(B, num_copies, seq_len, num_channels) x = x.permute(1, 0, 2, 3) # (num_copies, B, seq_len, num_channels) From e5d06f8b9e105d45a8d1e2bffebee312dce076a9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Jul 2025 18:41:15 +0800 Subject: [PATCH 0423/1191] Remove use of mel_warp. This is basically in principle a rerun of 919, just making sure I didn't break anything. --- egs/librispeech/ASR/zapformer/model.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 75d3c1a2b8..3c119c0ce7 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -26,7 +26,6 @@ from scaling import ScaledLinear, convert_num_channels from icefall.utils import add_sos, make_pad_mask, time_warp -from icefall.exp_augment import MelWarp class AsrModel(nn.Module): @@ -88,15 +87,6 @@ def __init__( self.encoder_embed = encoder_embed self.encoder = encoder - self.mel_warp = MelWarp( - low_freq_hz=20, - high_freq_hz=-400, - sample_rate_hz=16000, - num_mel_bins=80, - p=0.666, - max_shift=1) - - self.use_transducer = use_transducer if use_transducer: # Modules for Transducer head @@ -443,7 +433,6 @@ def forward( x = x.reshape(num_copies, B, seq_len, num_channels) do_time_warp = True - do_mel_warp = True if do_time_warp: # Apply time warping. First append the copies on the channel # dimension so all copies get the exact same time-warping. @@ -459,19 +448,6 @@ def forward( x = x.reshape(B, seq_len, num_copies, num_channels) x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) - if do_mel_warp: - # Apply mel warping. First append the copies on the sequence - # dimension so all copies of the data get the exact same - # mel-warping. (this is done mostly for purposes of the reconstruction - # loss). - x = x.permute(1, 0, 2, 3) # (B, num_copies, seq_len, num_channels) - x = x.reshape(B, num_copies * seq_len, num_channels) - - with torch.amp.autocast('cuda', enabled=False): - x = self.mel_warp(x.to(torch.float)) - x = x.reshape(B, num_copies, seq_len, num_channels) - x = x.permute(1, 0, 2, 3) # (num_copies, B, seq_len, num_channels) - # x_no_specaug is several repeats of the 1st copy of the data, which # is the one not augmented with Musan. But it does have time # warping and mel warping. From 547a855471fdcd7aedb920afa6b6b2ad1bf3f313 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 02:13:02 +0800 Subject: [PATCH 0424/1191] Add cosine similarity loss with limit=0.1 --- egs/librispeech/ASR/zapformer/model.py | 43 +++++++++++++++++++++++++- egs/librispeech/ASR/zapformer/train.py | 5 ++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 3c119c0ce7..7f9da94ddd 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -28,6 +28,44 @@ from icefall.utils import add_sos, make_pad_mask, time_warp +class CosineSimilarityLoss(nn.Module): + def __init__(self, + max_similarity: float = 0.1): + super().__init__() + self.max_similarity = max_similarity + + def forward(self, + x: Tensor, + mask: Optional[Tensor] = None) -> Tensor: + """ + Compute cosine-similarity loss that tries to keep distinct output vectors distinct. + + x: Tensor of shape (..., num_channels) + mask: if supplied, any mask that broadcasts with x[.., 0]. + True means masked positions. + + Returns excess similarity as a sum over frames. + """ + eps = 1.0e-10 + x_norm = ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() + x = x / x_norm + if mask is not None: + x = x * (~mask).unsqueeze(-1).to(x.dtype) + num_channels = x.shape[-1] + x = x.reshape(-1, num_channels) + n = x.shape[0] + perm = torch.randperm(n, device=x.device) + arange = torch.arange(n, device=x.device) + perm = torch.where(perm != arange, perm, (arange + 1) % n) + #assert torch.all(perm != arange) + + x_permuted = torch.index_select(x, 0, perm) + + similarity = (x * x_permuted).sum(dim=-1) + excess_similarity = (similarity - self.max_similarity).relu() + return excess_similarity + + class AsrModel(nn.Module): def __init__( self, @@ -520,7 +558,10 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, encoder_out_lens) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss + cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.1)( + encoder_out, mask=make_pad_mask(encoder_out_lens)).sum() + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss def forward_reconstruction_loss(self, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 69e070eedc..2f13de1095 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -988,7 +988,7 @@ def compute_loss( spec_augment = None # disable spec-aug with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -1025,6 +1025,8 @@ def warmup_schedule(scale, initial_factor): loss += reconstruction_loss_scale * reconstruction_loss + loss += cosine_similarity_loss + if num_copies > 1: loss += params.predict_loss_scale * predict_loss @@ -1053,6 +1055,7 @@ def warmup_schedule(scale, initial_factor): if num_copies > 1: info["predict_loss"] = predict_loss.detach().cpu().item() info["recon_loss"] = reconstruction_loss.detach().cpu().item() + info["cosine_similarity_loss"] = cosine_similarity_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() From 5fac0a8180fe6fd890f0ff4f03fdcd5055374a33 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Aug 2025 22:31:26 +0800 Subject: [PATCH 0425/1191] Merge torh-compile speed fix. --- egs/librispeech/ASR/zipformer/scaling.py | 62 +++++++++++------------- 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index aa1a963007..d03943d920 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -19,7 +19,7 @@ import math import copy import random -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Any import k2 import torch @@ -1452,22 +1452,23 @@ def swashr_and_deriv(x: Tensor): return y, deriv -swashl_compiled = torch_compile(swashl) -swashr_compiled = torch_compile(swashr) -swashl_and_deriv_compiled = torch_compile(swashl_and_deriv) -swashr_and_deriv_compiled = torch_compile(swashr_and_deriv) - class SwashL(torch.nn.Module): + def __init__(self): + super().__init__() + self.func = torch_compile(swashl) def forward(self, x: Tensor) -> Tensor: """Return Swash-L activation, which is the same as SwooshL but with a factor of 4 on the input and 0.25 on the output..""" - return swashl_compiled(x) + return self.func(x) class SwashR(torch.nn.Module): + def __init__(self): + super().__init__() + self.func = torch_compile(swashr) def forward(self, x: Tensor) -> Tensor: """Return Swash-R activation, which is the same as SwooshL but with a factor of 4 on the input and 0.25 on the output..""" - return swashr_compiled(x) + return self.func(x) class SquareLogSoftmax(nn.Module): @@ -1497,7 +1498,8 @@ def forward( x: Tensor, weight: Tensor, bias: Optional[Tensor], - activation: str, + forward_func: Any, + backward_func: Any, dropout_p: float, dropout_shared_dim: Optional[int], ): @@ -1514,16 +1516,9 @@ def forward( ctx.save_for_backward(x, weight, bias, dropout_mask) - ctx.activation = activation + ctx.backward_func = backward_func - forward_activation_dict = { - "SwashL": swashl_compiled, - "SwashR": swashr_compiled, - } - # it will raise a KeyError if this fails. This will be an error. We let it - # propagate to the user. - activation_func = forward_activation_dict[activation] - x = activation_func(x) + x = forward_func(x) if dropout_mask is not None: x = x * dropout_mask x = torch.nn.functional.linear(x, weight, bias) @@ -1535,15 +1530,7 @@ def backward(ctx, ans_grad: Tensor): saved = ctx.saved_tensors (x, weight, bias, dropout_mask) = saved - forward_and_deriv_activation_dict = { - "SwashL": swashl_and_deriv_compiled, - "SwashR": swashr_and_deriv_compiled, - } - # the following lines a KeyError if the activation is unrecognized. - # This will be an error. We let it propagate to the user. - func = forward_and_deriv_activation_dict[ctx.activation] - - y, func_deriv = func(x) + y, func_deriv = ctx.backward_func(x) if dropout_mask is not None: y = y * dropout_mask # now compute derivative of y w.r.t. weight and bias.. @@ -1560,7 +1547,7 @@ def backward(ctx, ans_grad: Tensor): # order versus func_deriv does not matter x_deriv = x_deriv * dropout_mask - return x_deriv, weight_deriv, bias_deriv, None, None, None + return x_deriv, weight_deriv, bias_deriv, None, None, None, None @@ -1617,21 +1604,26 @@ def __init__( self.dropout_p = dropout_p self.dropout_shared_dim = dropout_shared_dim + assert activation in ["SwashL", "SwashR"] + if activation == "SwashL": + self.forward_func = torch_compile(swashl) + self.backward_func = torch_compile(swashl_and_deriv) + else: + self.forward_func = torch_compile(swashr) + self.backward_func = torch_compile(swashr_and_deriv) + + def forward(self, x: Tensor): if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == "SwashL": - x = swashl_compiled(x) - elif self.activation == "SwashR": - x = swashr_compiled(x) - else: - assert False, self.activation + x = self.forward_func(x) return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationDropoutAndLinearFunction.apply( x, self.weight, self.bias, - self.activation, + self.forward_func, + self.backward_func, float(self.dropout_p), self.dropout_shared_dim, ) From 6fae208460ad459a26d9172e3e9c755e58b52b28 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Aug 2025 18:07:56 +0800 Subject: [PATCH 0426/1191] Fix for bug in validation mode regarding batch size of 1. --- egs/librispeech/ASR/zapformer/model.py | 32 +++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 7f9da94ddd..37b81117f3 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -252,7 +252,7 @@ def forward_cr_ctc( targets: torch.Tensor, target_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute CTC loss with consistency regularization loss. + """Compute CTC loss, with consistency regularization loss if we are in training mode. Args: encoder_out: Encoder output, of shape (2 * N, T, C). @@ -526,21 +526,21 @@ def forward( if self.use_ctc: targets = y.values - #if not use_cr_ctc: - #ctc_loss = self.forward_ctc( - #encoder_out=encoder_out, - #encoder_out_lens=encoder_out_lens, - #targets=targets, - #target_lengths=y_lens, - #) - #cr_loss = torch.empty(0) - - ctc_loss, cr_loss = self.forward_cr_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) + if not self.training: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) else: ctc_loss = torch.empty(0) cr_loss = torch.empty(0) From ce1a6c86223ebe7ca4011ed3c2b9b4381e9c67cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 04:37:04 +0800 Subject: [PATCH 0427/1191] Have the penalty scale for OrthogonalDownsample fall to zero. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8b7bb6b010..f08c37e004 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -948,7 +948,8 @@ def __init__( ): super().__init__() assert proj_dim <= channels * 2 - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, + penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. # it will be interpreted by get_parameter_groups_with_lrs() self.proj.lr_scale = 0.75 From 488b1a8c4f77cdc2ad191673b7688a0c2612a7b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 17:38:42 +0800 Subject: [PATCH 0428/1191] Make the cosine similarity loss per sequence --- egs/librispeech/ASR/zapformer/model.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 37b81117f3..443857ea6c 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -40,8 +40,8 @@ def forward(self, """ Compute cosine-similarity loss that tries to keep distinct output vectors distinct. - x: Tensor of shape (..., num_channels) - mask: if supplied, any mask that broadcasts with x[.., 0]. + x: Tensor of shape (batch_size, seq_len, num_channels) + mask: if supplied, mask of shape (batch_size, seq_len); True means masked positions. Returns excess similarity as a sum over frames. @@ -49,17 +49,18 @@ def forward(self, eps = 1.0e-10 x_norm = ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() x = x / x_norm + (batch_size, seq_len, num_channels) = x.shape + _, permutation = torch.rand(batch_size, seq_len, device=x.device).sort(dim=1) + # permutation: (batch_size, seq_len) + arange = torch.arange(seq_len, device=x.device) + mask2 = (permutation == arange) if mask is not None: - x = x * (~mask).unsqueeze(-1).to(x.dtype) - num_channels = x.shape[-1] - x = x.reshape(-1, num_channels) - n = x.shape[0] - perm = torch.randperm(n, device=x.device) - arange = torch.arange(n, device=x.device) - perm = torch.where(perm != arange, perm, (arange + 1) % n) - #assert torch.all(perm != arange) - - x_permuted = torch.index_select(x, 0, perm) + mask = torch.logical_or(mask, mask2) + else: + mask = mask2 + x = x * (~mask).unsqueeze(-1).to(x.dtype) + + x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) similarity = (x * x_permuted).sum(dim=-1) excess_similarity = (similarity - self.max_similarity).relu() From 3a061224d74124a998791653b6ef664bce473cc5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 04:37:04 +0800 Subject: [PATCH 0429/1191] Reveerse penalty_scale change of 1003->1004 --- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f08c37e004..8b7bb6b010 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -948,8 +948,7 @@ def __init__( ): super().__init__() assert proj_dim <= channels * 2 - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, - penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. # it will be interpreted by get_parameter_groups_with_lrs() self.proj.lr_scale = 0.75 From 0040bae0ee8d0f21ba1ee98acca99dd731634213 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 18:35:48 +0800 Subject: [PATCH 0430/1191] Also penalize negative correlations. --- egs/librispeech/ASR/zapformer/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 443857ea6c..708a99d344 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -30,7 +30,7 @@ class CosineSimilarityLoss(nn.Module): def __init__(self, - max_similarity: float = 0.1): + max_similarity: float): # e.g. 0.1 for max_similarity super().__init__() self.max_similarity = max_similarity @@ -62,7 +62,7 @@ def forward(self, x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) - similarity = (x * x_permuted).sum(dim=-1) + similarity = (x * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also excess_similarity = (similarity - self.max_similarity).relu() return excess_similarity From 9e6e3709fe8caa699447270afb0cb8c610afc741 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 18:49:38 +0800 Subject: [PATCH 0431/1191] Decrease max_similarity from 0.1 to 0.05. --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 708a99d344..9f2c40c92d 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -559,7 +559,7 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, encoder_out_lens) - cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.1)( + cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05)( encoder_out, mask=make_pad_mask(encoder_out_lens)).sum() return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss From 85f17031e83bf5ecd9de5ec7c1e6ebcca8f6f36e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 19:09:44 +0800 Subject: [PATCH 0432/1191] Have the cosine similarity loss be computed at the end of each zipformer stack and averaged over the stacks. --- egs/librispeech/ASR/zapformer/model.py | 51 ++-------------------- egs/librispeech/ASR/zipformer/scaling.py | 39 +++++++++++++++++ egs/librispeech/ASR/zipformer/zipformer.py | 20 ++++++--- 3 files changed, 56 insertions(+), 54 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 9f2c40c92d..4dd91dbe9d 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -24,49 +24,9 @@ from torch import Tensor from encoder_interface import EncoderInterface from scaling import ScaledLinear, convert_num_channels - from icefall.utils import add_sos, make_pad_mask, time_warp -class CosineSimilarityLoss(nn.Module): - def __init__(self, - max_similarity: float): # e.g. 0.1 for max_similarity - super().__init__() - self.max_similarity = max_similarity - - def forward(self, - x: Tensor, - mask: Optional[Tensor] = None) -> Tensor: - """ - Compute cosine-similarity loss that tries to keep distinct output vectors distinct. - - x: Tensor of shape (batch_size, seq_len, num_channels) - mask: if supplied, mask of shape (batch_size, seq_len); - True means masked positions. - - Returns excess similarity as a sum over frames. - """ - eps = 1.0e-10 - x_norm = ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() - x = x / x_norm - (batch_size, seq_len, num_channels) = x.shape - _, permutation = torch.rand(batch_size, seq_len, device=x.device).sort(dim=1) - # permutation: (batch_size, seq_len) - arange = torch.arange(seq_len, device=x.device) - mask2 = (permutation == arange) - if mask is not None: - mask = torch.logical_or(mask, mask2) - else: - mask = mask2 - x = x * (~mask).unsqueeze(-1).to(x.dtype) - - x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) - - similarity = (x * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also - excess_similarity = (similarity - self.max_similarity).relu() - return excess_similarity - - class AsrModel(nn.Module): def __init__( self, @@ -199,12 +159,12 @@ def forward_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask) + encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens, predict_loss + return encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss def forward_ctc( self, @@ -408,7 +368,7 @@ def forward( supervision_segments: Optional[torch.Tensor] = None, time_warp_factor: Optional[int] = 80, num_copies: int = 1, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -505,7 +465,7 @@ def forward( # Compute encoder outputs - encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens) + encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss = self.forward_encoder(x, x_lens) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -559,9 +519,6 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, encoder_out_lens) - cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05)( - encoder_out, mask=make_pad_mask(encoder_out_lens)).sum() - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d03943d920..7ca3ddcf0a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -848,6 +848,45 @@ def forward(self, x: Tensor): return ans +class CosineSimilarityLoss(nn.Module): + def __init__(self, + max_similarity: float): # e.g. 0.1 for max_similarity + super().__init__() + self.max_similarity = max_similarity + + def forward(self, + x: Tensor, + mask: Optional[Tensor] = None) -> Tensor: + """ + Compute cosine-similarity loss that tries to keep distinct output vectors distinct. + + x: Tensor of shape (batch_size, seq_len, num_channels) + mask: if supplied, mask of shape (batch_size, seq_len); + True means masked positions. + + Returns excess similarity as a sum over frames. + """ + eps = 1.0e-10 + x_norm = ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() + x = x / x_norm + (batch_size, seq_len, num_channels) = x.shape + _, permutation = torch.rand(batch_size, seq_len, device=x.device).sort(dim=1) + # permutation: (batch_size, seq_len) + arange = torch.arange(seq_len, device=x.device) + mask2 = (permutation == arange) + if mask is not None: + mask = torch.logical_or(mask, mask2) + else: + mask = mask2 + x = x * (~mask).unsqueeze(-1).to(x.dtype) + + x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) + + similarity = (x * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also + excess_similarity = (similarity - self.max_similarity).relu() + return excess_similarity + + class ChunkCausalDepthwiseConv1d(torch.nn.Module): """ diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8b7bb6b010..509b9422eb 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -33,6 +33,7 @@ ActivationDropoutAndLinear, ExpNorm, ChunkCausalDepthwiseConv1d, + CosineSimilarityLoss, Dropout2, FloatLike, ScheduledFloat, @@ -232,7 +233,7 @@ def forward( x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, specaug_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ Args: x: @@ -244,12 +245,14 @@ def forward( The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. Returns: - Return a tuple containing 2 tensors: + Return a tuple containing 4 tensors: - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - lengths, a tensor of shape (batch_size,) containing the number of frames in `embeddings` before padding. + - predict_loss, a cross-prediction loss of randomized codebooks, relying on the CR-CTC + structure of the batch. + - cosine_similarity_loss, a loss that encourages embedding vectors to be independent. """ - chunk_size, left_context_chunks = self.get_chunk_info() if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -266,13 +269,14 @@ def truncate(x, downsampling_factor): predict_loss = 0.0 + cosine_similarity_loss = 0.0 for module in self.encoders: if isinstance(module, Zipformer2Encoder): i = module.encoder_index # was set in this class's __init__ function. ds = self.downsampling_factor[i] x = truncate(x, ds) - x, this_pred_loss = module( + x, this_pred_loss, this_cosine_similarity_loss = module( x, chunk_size=chunk_size, src_key_padding_mask=( @@ -291,6 +295,7 @@ def truncate(x, downsampling_factor): ), ) predict_loss += this_pred_loss * (ds / self.output_downsampling_factor) + cosine_similarity_loss += this_cosine_similarity_loss * (ds / self.output_downsampling_factor) else: x = module(x) @@ -303,7 +308,8 @@ def truncate(x, downsampling_factor): warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - return x, lengths, predict_loss / len(self.downsampling_factor) + L = len(self.downsampling_factor) + return x, lengths, predict_loss / L, cosine_similarity_loss / L def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -757,7 +763,7 @@ def __init__( ) self.predict_loss = PredictLoss(dim) - + self.cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05) def forward( self, @@ -818,7 +824,7 @@ def forward( else: mask = None - return src, self.predict_loss(src, mask) + return src, self.predict_loss(src, mask), self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask).sum() def streaming_forward( self, From 674d6ff12ca76edd7e21aa4b594116182048d433 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 Aug 2025 19:19:50 +0800 Subject: [PATCH 0433/1191] Hardcode cosine_similarity_loss_scale=0.25 (was 1.0) --- egs/librispeech/ASR/zapformer/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2f13de1095..8b2b4d3879 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1025,7 +1025,8 @@ def warmup_schedule(scale, initial_factor): loss += reconstruction_loss_scale * reconstruction_loss - loss += cosine_similarity_loss + cosine_similarity_loss_scale = 0.25 + loss += cosine_similarity_loss * cosine_similarity_loss_scale if num_copies > 1: loss += params.predict_loss_scale * predict_loss From a79476694f46564f733a40a76c252a72af480882 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 22 Aug 2025 01:03:37 +0800 Subject: [PATCH 0434/1191] Fix decode script --- egs/librispeech/ASR/zapformer/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 504d1d94d2..85883ea113 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -452,7 +452,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens, _predict_loss = model.forward_encoder(feature, feature_lens) + encoder_out, encoder_out_lens, _predict_loss, _cosine_loss = model.forward_encoder(feature, feature_lens) hyps = [] From cb4c7171d3648e105b6922a49b09cacb2fc16c94 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 Aug 2025 22:51:20 +0800 Subject: [PATCH 0435/1191] Change how upsampling and downsampling is done, materialize only part of the matrix. --- egs/librispeech/ASR/zipformer/scaling.py | 125 +++++++++++++- egs/librispeech/ASR/zipformer/zipformer.py | 185 ++++++++++++--------- 2 files changed, 228 insertions(+), 82 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d03943d920..3cc7538fac 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -848,6 +848,123 @@ def forward(self, x: Tensor): return ans +class SimpleOrthogonalPenaltyFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, weight: Tensor, penalty_scale: float, name: str): + ctx.save_for_backward(weight) + ctx.name = name + ctx.penalty_scale = penalty_scale + return weight + + @staticmethod + @custom_bwd + def backward(ctx, weight_grad): + weight, = ctx.saved_tensors + + if weight.requires_grad and ctx.penalty_scale != 0.0: + penalty_scale = ctx.penalty_scale * weight_grad.abs().mean() + + with torch.enable_grad(): + weight = weight.detach() + weight.requires_grad = True + + # Compute symmetric matrix-product prod with the smallest + # dimension possible given the shape of w. This is not just for + # efficiency; if we computed it the wrong way round, the product + # would have deficient rank and could never be the identity. + if (weight.shape[0] > weight.shape[1]): + prod = torch.matmul(weight.t(), weight) + else: + prod = torch.matmul(weight, weight.t()) + + # we'll try to enforce that for any i, prod[i] is any constant times the identity. + + # in the loss-function: + # orthogonality_loss = ((prod - I) ** 2).sum(), + + # note, prod_diag shares memory with prod, this will matter later on. + (r, c) = prod.shape + (r_stride, c_stride) = prod.stride() + + def diag_inplace(z): + return torch.as_strided(z, size=(r,), stride=(r_stride+c_stride,)) + + diag_inplace(prod)[:] -= 1. + + # that loss that we want to backprop would be 0.5 * (prod ** + # 2).sum() * penalty_scale. we can backprop this without doing + # any reductions as follows: + prod.backward(gradient=prod * penalty_scale) + + + do_print = random.random() < 0.002 + if do_print: + # we print a normalized version of the loss, by dividing by the + # number of rows. + loss = (prod ** 2).mean(dim=(1,2)) * prod.shape[1] + logging.info(f"OrthogonalLinear: name={ctx.name}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") + + + # add the extra gradient term from the orthogonality loss. + weight_grad = weight_grad + weight.grad + return weight_grad, None, None + +class SimpleOrthogonalLinear(nn.Linear): + """ + Like nn.Linear but can enforce that the weight matrix is orthogonal; in the non-square + case this is interpreted as either M^T M == I or M M^T == I, whichever would give a smaller + dimension. + (If M is square, these definitions are equivalent and is equivalent to the normal + definition of orthogonal). + + Args: + in_channels: number of input channels + out_channels: number of output channels + bias: if True, include a bias term. + penalty_scale: a scale on the penalty on non-orthogonality (this will + be multiplied by the average-absolute-value of the + backpropagated gradient). + """ + # if in_groups or out_groups are set to >1, the orthogonal constraint + # will be set per group. both of them cannot be >1. + def __init__(self, + in_channels: int, + out_channels: int, + in_groups: int = -1, + out_groups: int = -1, + group_size: int = -1, + bias: bool = True, + penalty_scale: FloatLike = 20.0, + ): + super().__init__(in_channels, out_channels, bias=bias) + self.name = None + self.in_groups = in_groups + self.out_groups = out_groups + if in_groups > 0 and group_size == -1: + group_size = in_channels // in_groups + elif out_groups > 0 and group_size == -1: + group_size = out_channels // out_groups + self.group_size = group_size + self.penalty_scale = copy.deepcopy(penalty_scale) + + with torch.no_grad(): + self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) + if self.bias is not None: + torch.nn.init.uniform_(self.bias, -0.01, 0.01) + + + def forward(self, x: Tensor, transpose: bool = False): + # you can only use transpose=True if you used bias=False in initialization + weight = self.weight + if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): + weight = SimpleOrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) + + if transpose: + weight = weight.t() + return torch.nn.functional.linear(x, weight, self.bias) + + class ChunkCausalDepthwiseConv1d(torch.nn.Module): """ @@ -1216,8 +1333,7 @@ def __init__( prob: the probability with which we apply the gradient modification (also affects the grad scale). May be supplied as a float, or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, + grad_scale: determines the scale on the gradient term from this object, relative to the rest of the gradient on the attention weights. E.g. 0.02 (you may want to use smaller values than this if prob is large) """ @@ -1822,6 +1938,10 @@ def _test_orthogonal_linear(): m = OrthogonalLinear(128, 128) m(torch.randn(30, 2, 128)) +def _test_simple_orthogonal_linear(): + m = SimpleOrthogonalLinear(128, 128) + m(torch.randn(30, 2, 128)) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) @@ -1834,3 +1954,4 @@ def _test_orthogonal_linear(): _test_swashl_deriv() _test_activation_dropout_and_linear() _test_orthogonal_linear() + _test_simple_orthogonal_linear() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8b7bb6b010..52573b4923 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -28,6 +28,7 @@ from scaling import ( Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. OrthogonalLinear, + SimpleOrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ScaleLimiter, ActivationDropoutAndLinear, @@ -145,7 +146,6 @@ def _to_tuple(x): encoders = [] num_encoders = len(downsampling_factor) - cur_downsample = 1 # caution: some changes we made for this break the streaming, later we'll try to fix this. encoders_downsampling_factors = [ ] @@ -153,21 +153,8 @@ def _to_tuple(x): # make it so large the limit is never reached. max_proj_dim = max(downsampling_factor) * max(encoder_dim) - def set_downsample_factor(cur_downsample, ds): - while cur_downsample < ds: - # need to downsample - encoders.append(OrthogonalDownsample(channels=input_dim * cur_downsample, - proj_dim=min(2 * input_dim * cur_downsample, max_proj_dim))) - cur_downsample *= 2 - while cur_downsample > ds: - encoders.append(OrthogonalUpsample(channels=input_dim * cur_downsample, - proj_dim=min(input_dim * cur_downsample, max_proj_dim))) - cur_downsample //= 2 - return cur_downsample for i in range(num_encoders): - cur_downsample = set_downsample_factor(cur_downsample, downsampling_factor[i]) - encoder_layer = Zipformer2EncoderLayer( embed_dim=encoder_dim[i], pos_dim=pos_dim, @@ -186,22 +173,17 @@ def set_downsample_factor(cur_downsample, ds): encoder = Zipformer2Encoder( encoder_layer, num_encoder_layers[i], - dim=cur_downsample*input_dim, + dim=downsampling_factor[i]*input_dim, pos_dim=pos_dim, ) - encoder.encoder_index = i # <-- will be used in streaming_forward encoders.append(encoder) - - cur_downsample = set_downsample_factor(cur_downsample, output_downsampling_factor) - self.encoders = nn.ModuleList(encoders) - def get_chunk_info(self) -> Tuple[int, int]: """ - Returns chunk_size and left_context_chunks. - """ + Returns chunk_size and left_context_chunks. + """ if not self.causal: return -1, -1 @@ -243,14 +225,22 @@ def forward( src_key_padding_mask: The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. + specaug_mask: + The mask that shows which frames were masked with specaug, of shape (batch_size, seq_len); + True means masked position. May be None. Returns: Return a tuple containing 2 tensors: - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - lengths, a tensor of shape (batch_size,) containing the number of frames in `embeddings` before padding. """ - chunk_size, left_context_chunks = self.get_chunk_info() + orig_seq_len = x.shape[0] + + pad = (-orig_seq_len) % max(self.downsampling_factor) + # pad sequence length to be multiple of max(self.downsampling_factor) + x = torch.cat((x, torch.zeros(pad, x.shape[1], x.shape[2], dtype=x.dtype, device=x.device)), + dim=0) if torch.jit.is_scripting() or torch.jit.is_tracing(): # Not support exporting a model for simulating streaming decoding @@ -258,44 +248,42 @@ def forward( else: attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - orig_seq_len = x.shape[0] - - def truncate(x, downsampling_factor): - max_len = (orig_seq_len + downsampling_factor - 1) // downsampling_factor - return x[:max_len] if x.shape[0] > max_len else x - + src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) + specaug_mask = pad_mask(specaug_mask, x.shape[0]) predict_loss = 0.0 - for module in self.encoders: - if isinstance(module, Zipformer2Encoder): - i = module.encoder_index # was set in this class's __init__ function. - ds = self.downsampling_factor[i] - x = truncate(x, ds) - x, this_pred_loss = module( - x, - chunk_size=chunk_size, - src_key_padding_mask=( - None - if src_key_padding_mask is None - else src_key_padding_mask[..., ::ds] - ), - specaug_mask=( - None - if specaug_mask is None - else specaug_mask[..., ::ds] - ), - attn_mask=(None - if attn_mask is None - else attn_mask[::ds, ::ds] - ), - ) - predict_loss += this_pred_loss * (ds / self.output_downsampling_factor) + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = downsample_by(x, ds) + T = x.shape[0] + x, this_pred_loss = module( + x, + chunk_size=chunk_size, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + specaug_mask=( + None + if specaug_mask is None + else specaug_mask[..., ::ds] + ), + attn_mask=(None + if attn_mask is None + else attn_mask[::ds, ::ds] + ), + ) + x = upsample_by(x, ds) + predict_loss += this_pred_loss * (ds / self.output_downsampling_factor) - else: - x = module(x) assert self.output_downsampling_factor == 2, self.output_downsampling_factor + od = self.output_downsampling_factor + x = downsample_by(x, od) + x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 else: @@ -348,6 +336,7 @@ def _get_attn_mask( logging.info(f"attn_mask = {attn_mask}") return attn_mask + def streaming_forward( self, x: Tensor, @@ -379,21 +368,18 @@ def streaming_forward( layer_offset = 0 for module in enumerate(self.encoders): - if not isinstance(module, Zipformer2Encoder): - x = module(x) - else: - i = module.encoder_index # was set in this class's __init__ function. - num_layers = module.num_layers - ds = self.downsampling_factor[i] - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - new_states += new_layer_states + i = module.encoder_index # was set in this class's __init__ function. + num_layers = module.num_layers + ds = self.downsampling_factor[i] + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + new_states += new_layer_states x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. @@ -468,6 +454,40 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: def _balancer_schedule(min_prob: float): return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + +def pad_mask(mask: Optional[Tensor], seq_len: int): + # mask: (batch_size, old_seq_len) + # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). + if mask is None: + return None + (batch_size, old_seq_len) = mask.shape + pad = seq_len - old_seq_len + if pad == 0: + return mask + else: + return torch.cat((mask, torch.ones(bath_size, pad, device=mask.device, dtype=torch.bool)), + dim=1) + + +def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len // downsampling_factor, batch_size, num_channels * downsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len // downsampling_factor, downsampling_factor, batch_size, num_channels) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len // downsampling_factor, batch_size, downsampling_factor * num_channels) + return x + +def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len, batch_size, upsampling_factor, num_channels // upsampling_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + return x + + class Zipformer2EncoderLayer(nn.Module): """ Args: @@ -734,6 +754,10 @@ def __init__( pos_dim: int, ) -> None: super().__init__() + + # self.downsample will also reverse the downsampling operation for us afterward. + self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) + self.encoder_pos = CompactRelPositionalEncoding( pos_dim, dropout_rate=0.0, length_factor=1.0 ) @@ -743,7 +767,6 @@ def __init__( ) self.num_layers = num_layers - self.residual = ResidualModule(encoder_layer.embed_dim) #bypass_dim = dim - encoder_layer.embed_dim @@ -784,13 +807,12 @@ def forward( """ pos_emb = self.encoder_pos(src) - num_channels = src.shape[-1] - layer_dim = self.layers[0].embed_dim - if num_channels > layer_dim: - src, bypass = src[..., :layer_dim], src[..., layer_dim:] + src_orig_fulldim = src + src = self.proj(src) # project to layer dim. src_orig = src + for i, mod in enumerate(self.layers): src = mod( src, @@ -805,9 +827,8 @@ def forward( src = self.residual(src_orig, src) src = self.whiten(src) - if num_channels > layer_dim: - bypass = self.copy_bypass(bypass) - src = torch.cat((src, bypass), dim=-1) + # the following takes care of passing through the "rejected" dimension. + src = src_orig_fulldim + self.proj(src - src_orig, transpose=True) if src_key_padding_mask is not None and specaug_mask is not None: mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) @@ -929,10 +950,14 @@ def forward(self, src_orig: Tensor, src: Tensor): return residual_scale * src_orig + function_scale * src - class OrthogonalDownsample(torch.nn.Module): """ - Does downsampling with an orthogonal matrix, by a factor of two. Projection is initialized + Downsamples on sequence axis by appending sequence-positions together, + and then optionally projects by an orthogonal matrix + + + +. Projection is initialized in a special way and enforced to be orthogonal. Args: From 1ac12c0fe4777b2093f429f5025fd1b8800be7f9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 Aug 2025 22:56:15 +0800 Subject: [PATCH 0436/1191] Bug fix --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 52573b4923..c1f3f47b01 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -465,7 +465,7 @@ def pad_mask(mask: Optional[Tensor], seq_len: int): if pad == 0: return mask else: - return torch.cat((mask, torch.ones(bath_size, pad, device=mask.device, dtype=torch.bool)), + return torch.cat((mask, torch.ones(batch_size, pad, device=mask.device, dtype=torch.bool)), dim=1) From 8b9cec3a0bb9441af1b3a40443aff073616c6cee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 Aug 2025 23:00:23 +0800 Subject: [PATCH 0437/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3cc7538fac..8956ed541e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -902,8 +902,8 @@ def diag_inplace(z): if do_print: # we print a normalized version of the loss, by dividing by the # number of rows. - loss = (prod ** 2).mean(dim=(1,2)) * prod.shape[1] - logging.info(f"OrthogonalLinear: name={ctx.name}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") + loss = (prod ** 2).mean() + logging.info(f"OrthogonalLinear: name={ctx.name}, loss={loss.detach().cpu()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") # add the extra gradient term from the orthogonality loss. From 477041adf94b1f7c01d1a048aa4cd9b3ae48f223 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 Aug 2025 23:48:38 +0800 Subject: [PATCH 0438/1191] set proj.lr_scale=0.75 --- egs/librispeech/ASR/zipformer/zipformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index c1f3f47b01..206ede6bea 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -757,6 +757,7 @@ def __init__( # self.downsample will also reverse the downsampling operation for us afterward. self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) + self.proj.lr_scale = 0.75 self.encoder_pos = CompactRelPositionalEncoding( pos_dim, dropout_rate=0.0, length_factor=1.0 From c617a9f37cc40a46b2e689cf8f5f09ca3c6038c4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 00:06:10 +0800 Subject: [PATCH 0439/1191] Bug fix only for streaming. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 206ede6bea..99590123d3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -367,8 +367,7 @@ def streaming_forward( new_states = [] layer_offset = 0 - for module in enumerate(self.encoders): - i = module.encoder_index # was set in this class's __init__ function. + for i, module in enumerate(self.encoders): num_layers = module.num_layers ds = self.downsampling_factor[i] From a565a18da8018d1d44d834f7218f698bc6db9436 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 00:27:32 +0800 Subject: [PATCH 0440/1191] Scale of 0.65 in each Zipformer2Encoder --- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 99590123d3..4d072f74f7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -756,7 +756,6 @@ def __init__( # self.downsample will also reverse the downsampling operation for us afterward. self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) - self.proj.lr_scale = 0.75 self.encoder_pos = CompactRelPositionalEncoding( pos_dim, dropout_rate=0.0, length_factor=1.0 @@ -830,6 +829,8 @@ def forward( # the following takes care of passing through the "rejected" dimension. src = src_orig_fulldim + self.proj(src - src_orig, transpose=True) + src = 0.65 * src + if src_key_padding_mask is not None and specaug_mask is not None: mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) elif src_key_padding_mask is not None: From 93627865feff91dc521473fd9bd265befdb426ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 00:29:30 +0800 Subject: [PATCH 0441/1191] Increase max_similarity threshold from .05 to .1 --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 9f2c40c92d..708a99d344 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -559,7 +559,7 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, encoder_out_lens) - cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05)( + cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.1)( encoder_out, mask=make_pad_mask(encoder_out_lens)).sum() return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss From 304ecb819bbd3cf8bc6fd402c27525d96572a65b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 00:37:54 +0800 Subject: [PATCH 0442/1191] Remove scale of 0.65 I added 2 commits ago --- egs/librispeech/ASR/zipformer/zipformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4d072f74f7..bfb7338786 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -829,8 +829,6 @@ def forward( # the following takes care of passing through the "rejected" dimension. src = src_orig_fulldim + self.proj(src - src_orig, transpose=True) - src = 0.65 * src - if src_key_padding_mask is not None and specaug_mask is not None: mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) elif src_key_padding_mask is not None: From 1deb2d69efe419e43a42edc8bd85ec2d74416cbf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 00:48:15 +0800 Subject: [PATCH 0443/1191] Revert max_similarity from .1 to .05; pad with repeats of last frame, not zeros. --- egs/librispeech/ASR/zapformer/model.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 708a99d344..9f2c40c92d 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -559,7 +559,7 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, encoder_out_lens) - cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.1)( + cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05)( encoder_out, mask=make_pad_mask(encoder_out_lens)).sum() return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bfb7338786..b9f8044413 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -239,7 +239,7 @@ def forward( pad = (-orig_seq_len) % max(self.downsampling_factor) # pad sequence length to be multiple of max(self.downsampling_factor) - x = torch.cat((x, torch.zeros(pad, x.shape[1], x.shape[2], dtype=x.dtype, device=x.device)), + x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), dim=0) if torch.jit.is_scripting() or torch.jit.is_tracing(): From bc926485afe2d73262c30784d80fbed69b8bd709 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 12:01:09 +0800 Subject: [PATCH 0444/1191] Set back lr_scale=0.75 on the .proj of ZipformerEncoder. --- egs/librispeech/ASR/zipformer/zipformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b9f8044413..dbcbcce3bd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -756,6 +756,7 @@ def __init__( # self.downsample will also reverse the downsampling operation for us afterward. self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) + self.proj.lr_scale = 0.75 self.encoder_pos = CompactRelPositionalEncoding( pos_dim, dropout_rate=0.0, length_factor=1.0 From 62a227831aaab8782b5f82ab35d3133d76eb041b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 21:04:37 +0800 Subject: [PATCH 0445/1191] Special cosine-based initialization of projections --- egs/librispeech/ASR/zipformer/zipformer.py | 37 ++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index dbcbcce3bd..aef98ed080 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -176,6 +176,9 @@ def _to_tuple(x): dim=downsampling_factor[i]*input_dim, pos_dim=pos_dim, ) + + init_proj_special(input_dim, downsampling_factor[i], encoder.proj.weight) + encoders.append(encoder) self.encoders = nn.ModuleList(encoders) @@ -487,6 +490,40 @@ def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: return x +def get_dct_matrix(N): + """ + Generates an orthonormal DCT-II matrix for a given size N. + Args: + N (int): The size of the square matrix. + Returns: + torch.Tensor: The N x N orthonormal DCT-II matrix. + """ + # Create the base matrix with dimensions (N, N) + mat = torch.zeros(N, N) + # Create a tensor for the indices k (rows) and n (columns) + k = torch.arange(N).unsqueeze(1) + n = torch.arange(N).unsqueeze(0) + # Fill the matrix using the DCT-II formula + mat = math.sqrt(2 / N) * torch.cos(math.pi / (2 * N) * (2 * n + 1) * k) + # Adjust the first row (k=0) with a special normalization factor + mat[0] *= (2 ** -0.5) + return mat + +def init_proj_special(input_dim: int, downsampling_factor: int, weight_out: Tensor): + # special initialization of projection weight with orthonormal rows, so that low-freq + (num_rows, num_cols) = weight_out.shape + assert num_cols == input_dim * downsampling_factor and num_rows <= num_cols + weight = torch.eye(num_cols) + d = downsampling_factor + n = input_dim + weight = weight.reshape(d, n, d, n) + dct = get_dct_matrix(d) + weight = torch.matmul(dct, weight.reshape(d, -1)) + weight = weight.reshape(d * n, d * n) + with torch.no_grad(): + weight_out[:] = weight[:num_rows, :] + + class Zipformer2EncoderLayer(nn.Module): """ Args: From 2d25ad0a109664ea26c637f8ddd2ceb0691a9ca8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Aug 2025 22:47:26 +0800 Subject: [PATCH 0446/1191] Increase scheduler warmup_batches from 500 to 1500. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2f13de1095..6b1b1838e4 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1415,7 +1415,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), warmup_batches=1500) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 2a21bc98fead0f13aaf61229d9123e7cb66722ad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 12:52:58 +0800 Subject: [PATCH 0447/1191] Have cosine similarity loss be aggregated over sequence before applying the threshold. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7ca3ddcf0a..3b98ea0743 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -883,8 +883,8 @@ def forward(self, x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) similarity = (x * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also - excess_similarity = (similarity - self.max_similarity).relu() - return excess_similarity + excess_similarity = (similarity.sum(dim=1) - seq_len * self.max_similarity).relu() + return excess_similarity.sum() # sum over batch dim. diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 509b9422eb..be3486718b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -824,7 +824,7 @@ def forward( else: mask = None - return src, self.predict_loss(src, mask), self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask).sum() + return src, self.predict_loss(src, mask), self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask) def streaming_forward( self, From fd498e1f162f0f9d778d07688c9faaa7dce6adb0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 13:32:11 +0800 Subject: [PATCH 0448/1191] Have cosine_similarity_loss be applied in backprop, do not return loss value. --- egs/librispeech/ASR/zapformer/model.py | 13 +++++--- egs/librispeech/ASR/zapformer/train.py | 9 +++-- egs/librispeech/ASR/zipformer/zipformer.py | 38 +++++++++++++++++----- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 4dd91dbe9d..7cf2a2781e 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -129,7 +129,7 @@ def __init__( def forward_encoder( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute encoder outputs. Args: @@ -159,12 +159,13 @@ def forward_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask) + encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask, + aux_loss_scale=aux_loss_scale) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss + return encoder_out, encoder_out_lens, predict_loss def forward_ctc( self, @@ -368,6 +369,7 @@ def forward( supervision_segments: Optional[torch.Tensor] = None, time_warp_factor: Optional[int] = 80, num_copies: int = 1, + aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -465,7 +467,8 @@ def forward( # Compute encoder outputs - encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss = self.forward_encoder(x, x_lens) + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -519,7 +522,7 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, encoder_out_lens) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss def forward_reconstruction_loss(self, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 8b2b4d3879..25d7c75da4 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -936,6 +936,7 @@ def compute_loss( batch: dict, is_training: bool, spec_augment: Optional[nn.Module] = None, + aux_loss_scale: float = 0.0, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -988,7 +989,7 @@ def compute_loss( spec_augment = None # disable spec-aug with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -999,6 +1000,7 @@ def compute_loss( supervision_segments=supervision_segments, time_warp_factor=80, # for specaug num_copies=num_copies, + aux_loss_scale=aux_loss_scale, ) loss = 0.0 @@ -1025,9 +1027,6 @@ def warmup_schedule(scale, initial_factor): loss += reconstruction_loss_scale * reconstruction_loss - cosine_similarity_loss_scale = 0.25 - loss += cosine_similarity_loss * cosine_similarity_loss_scale - if num_copies > 1: loss += params.predict_loss_scale * predict_loss @@ -1056,7 +1055,6 @@ def warmup_schedule(scale, initial_factor): if num_copies > 1: info["predict_loss"] = predict_loss.detach().cpu().item() info["recon_loss"] = reconstruction_loss.detach().cpu().item() - info["cosine_similarity_loss"] = cosine_similarity_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -1184,6 +1182,7 @@ def save_bad_model(suffix: str = ""): batch=batch, is_training=True, spec_augment=spec_augment, + aux_loss_scale=scaler._scale.item() if params.use_autocast else 1.0, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index be3486718b..933b023118 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -43,6 +43,7 @@ penalize_abs_values_gt, PredictLoss, softmax, + with_loss, ) from torch import Tensor, nn @@ -233,6 +234,7 @@ def forward( x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, specaug_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ Args: @@ -244,6 +246,10 @@ def forward( src_key_padding_mask: The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) Returns: Return a tuple containing 4 tensors: - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) @@ -251,7 +257,6 @@ def forward( of frames in `embeddings` before padding. - predict_loss, a cross-prediction loss of randomized codebooks, relying on the CR-CTC structure of the batch. - - cosine_similarity_loss, a loss that encourages embedding vectors to be independent. """ chunk_size, left_context_chunks = self.get_chunk_info() @@ -268,15 +273,16 @@ def truncate(x, downsampling_factor): return x[:max_len] if x.shape[0] > max_len else x + num_stacks = len(self.downsampling_factor) + predict_loss = 0.0 - cosine_similarity_loss = 0.0 for module in self.encoders: if isinstance(module, Zipformer2Encoder): i = module.encoder_index # was set in this class's __init__ function. ds = self.downsampling_factor[i] x = truncate(x, ds) - x, this_pred_loss, this_cosine_similarity_loss = module( + x, this_pred_loss = module( x, chunk_size=chunk_size, src_key_padding_mask=( @@ -293,9 +299,9 @@ def truncate(x, downsampling_factor): if attn_mask is None else attn_mask[::ds, ::ds] ), + aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) ) - predict_loss += this_pred_loss * (ds / self.output_downsampling_factor) - cosine_similarity_loss += this_cosine_similarity_loss * (ds / self.output_downsampling_factor) + predict_loss += this_pred_loss * (ds / (self.output_downsampling_factor * num_stacks)) else: x = module(x) @@ -308,8 +314,7 @@ def truncate(x, downsampling_factor): warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - L = len(self.downsampling_factor) - return x, lengths, predict_loss / L, cosine_similarity_loss / L + return x, lengths, predict_loss def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -546,6 +551,7 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tensor: """ Pass the input through the encoder layer. @@ -558,6 +564,10 @@ def forward( True means masked position. May be None. src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) Returns: A tensor which has the same shape as src @@ -772,6 +782,7 @@ def forward( attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, specaug_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -795,7 +806,7 @@ def forward( if num_channels > layer_dim: src, bypass = src[..., :layer_dim], src[..., layer_dim:] - + num_layers = len(self.layers) src_orig = src for i, mod in enumerate(self.layers): src = mod( @@ -804,6 +815,7 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale/num_layers, ) # randomize_factor can be viewed as a simple version of an # importance-sampling factor. @@ -824,7 +836,15 @@ def forward( else: mask = None - return src, self.predict_loss(src, mask), self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask) + + # we will apply cosine_similarity_loss during backprop without printing it + # the 0.25 is a heuristic factor specific to cosine similarity loss. + if aux_loss_scale: # if not None and not zero.. + src = with_loss(src, + self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, + name=self.name) + + return src, self.predict_loss(src, mask) def streaming_forward( self, From 8eff73675c24db7b534f54d1c63608a21eba828c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 14:40:32 +0800 Subject: [PATCH 0449/1191] Change how we get scaler scale --- egs/librispeech/ASR/zapformer/train.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 25d7c75da4..2f706a036f 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1149,6 +1149,12 @@ def train_one_epoch( saved_bad_model = False + def get_scaler_scale(): + if params.use_autocast and scaler._scale is not None: + return scaler._scale.item() + else: + return 1.0 + def save_bad_model(suffix: str = ""): if params.debug_interval > 0: optimizer.write_debug_info(summary_writer=tb_writer) @@ -1182,7 +1188,7 @@ def save_bad_model(suffix: str = ""): batch=batch, is_training=True, spec_augment=spec_augment, - aux_loss_scale=scaler._scale.item() if params.use_autocast else 1.0, + aux_loss_scale=get_scaler_scale(), ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,7 +1244,7 @@ def save_bad_model(suffix: str = ""): ) if params.use_autocast: - cur_grad_scale = scaler._scale.item() + cur_grad_scale = get_scaler_scale() if cur_grad_scale < 0.01: if not saved_bad_model: @@ -1262,7 +1268,7 @@ def save_bad_model(suffix: str = ""): if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + cur_grad_scale = get_scaler_scale() logging.info( f"Epoch {params.cur_epoch}, " From 1e3cd71dd6b914014521aaff6445132303b6c269 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 15:21:59 +0800 Subject: [PATCH 0450/1191] Remove whitening from zipformer layers; code changes that affect printouts. --- egs/librispeech/ASR/zipformer/scaling.py | 11 ++++++++--- egs/librispeech/ASR/zipformer/zipformer.py | 10 +--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3b98ea0743..e4262bc0e2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -850,9 +850,10 @@ def forward(self, x: Tensor): class CosineSimilarityLoss(nn.Module): def __init__(self, - max_similarity: float): # e.g. 0.1 for max_similarity + max_similarity: FloatLike): # e.g. 0.1 for max_similarity super().__init__() self.max_similarity = max_similarity + self.name = None def forward(self, x: Tensor, @@ -864,7 +865,7 @@ def forward(self, mask: if supplied, mask of shape (batch_size, seq_len); True means masked positions. - Returns excess similarity as a sum over frames. + Returns excess similarity as a sum over frames, this should be treated as a loss. """ eps = 1.0e-10 x_norm = ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() @@ -883,7 +884,11 @@ def forward(self, x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) similarity = (x * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also - excess_similarity = (similarity.sum(dim=1) - seq_len * self.max_similarity).relu() + excess_similarity = (similarity.sum(dim=1) - seq_len * float(self.max_similarity)).relu() + + if random.random() < 0.001: + logging.info("CosineSimilarityLoss: {self.name}, limit={float(self.max_similarity}, excess-similarity={excess_similarity.mean() / seq_len}") + return excess_similarity.sum() # sum over batch dim. diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 933b023118..d6267e3ac9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -765,13 +765,6 @@ def __init__( #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(3.0), - prob=(1, 1), - grad_scale=0.025, - ) - self.predict_loss = PredictLoss(dim) self.cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05) @@ -821,7 +814,6 @@ def forward( # importance-sampling factor. src = self.residual(src_orig, src) - src = self.whiten(src) if num_channels > layer_dim: bypass = self.copy_bypass(bypass) @@ -842,7 +834,7 @@ def forward( if aux_loss_scale: # if not None and not zero.. src = with_loss(src, self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, - name=self.name) + name=None) return src, self.predict_loss(src, mask) From c5089304bb54e5ec808e27f52a2ac47f2c20fdb8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 16:01:06 +0800 Subject: [PATCH 0451/1191] Replace whitening in SelfAttention modules with cosine similarity loss --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 32 ++++++++++++---------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e4262bc0e2..3b3127fcc1 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -887,7 +887,7 @@ def forward(self, excess_similarity = (similarity.sum(dim=1) - seq_len * float(self.max_similarity)).relu() if random.random() < 0.001: - logging.info("CosineSimilarityLoss: {self.name}, limit={float(self.max_similarity}, excess-similarity={excess_similarity.mean() / seq_len}") + logging.info(f"CosineSimilarityLoss: {self.name}, limit={float(self.max_similarity)}, excess-similarity={excess_similarity.mean() / seq_len}") return excess_similarity.sum() # sum over batch dim. diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d6267e3ac9..7e7ded21e9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -584,13 +584,13 @@ def forward( src = src + self.feed_forward1(src) - src = src + self.self_attn1(src, attn_weights) + src = src + self.self_attn1(src, attn_weights, aux_loss_scale=aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward2(src) - src = src + self.self_attn2(src, attn_weights) + src = src + self.self_attn2(src, attn_weights, aux_loss_scale=aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) @@ -766,7 +766,7 @@ def __init__( self.copy_bypass = Identity() self.predict_loss = PredictLoss(dim) - self.cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05) + self.cosine_loss = CosineSimilarityLoss(max_similarity=0.05) def forward( self, @@ -829,11 +829,11 @@ def forward( mask = None - # we will apply cosine_similarity_loss during backprop without printing it + # we will apply cosine_loss during backprop without printing it # the 0.25 is a heuristic factor specific to cosine similarity loss. if aux_loss_scale: # if not None and not zero.. src = with_loss(src, - self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, + self.cosine_loss(src.permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, name=None) return src, self.predict_loss(src, mask) @@ -1536,20 +1536,17 @@ def __init__( ) f = max(1.0, embed_dim / (num_heads * value_head_dim)) - # the whitening metric cannot be less than f because of the rank imposed - # by the bottleneck. the final whitening limit will be (2.0*3.0) times f, - # i.e. 6 times greater than the mathematical smallest value it can have. - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(f * 2.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) + + + self.cosine_loss = CosineSimilarityLoss(max_similarity=0.25) + def forward( self, x: Tensor, attn_weights: Tensor, + aux_loss_scale: float = 0.0, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """ Args: @@ -1557,6 +1554,8 @@ def forward( attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect attn_weights.sum(dim=-1) == 1. + src_key_padding_mask: optional Tensor of shape (batch_size, src_seq_len); only + used for the cosine similarity loss, during training. Returns: a tensor with the same shape as x. """ @@ -1581,8 +1580,11 @@ def forward( # returned value is of shape (seq_len, batch_size, embed_dim), like the input. x = self.out_proj(x) - x = self.whiten(x) + if aux_loss_scale: + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), + src_key_padding_mask) * aux_loss_scale * 0.25, + name=None) return x def streaming_forward( From a7424310719965ef679101e1adba9fc3adc8156c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 17:43:04 +0800 Subject: [PATCH 0452/1191] Make cosine similarity loss limit increase from .5 to 0.8 (was .25); and do it per head. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7e7ded21e9..43e8bf5d87 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1537,8 +1537,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - - self.cosine_loss = CosineSimilarityLoss(max_similarity=0.25) + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.8), default=0.5)) def forward( @@ -1582,7 +1581,7 @@ def forward( x = self.out_proj(x) if aux_loss_scale: - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), + x = with_loss(x, self.cosine_loss(x.reshape(seq_len, batch_size * num_heads, value_head_dim).permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, name=None) return x From 571ed6dd5ad614a792445abea785831f7f59c8ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 17:48:45 +0800 Subject: [PATCH 0453/1191] Reverse the change about heads (made no sense); change max_similarity values to .25->.75 --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 43e8bf5d87..536fed21b1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1537,7 +1537,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.8), default=0.5)) + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.25), (20000.0, 0.75), default=0.5)) def forward( @@ -1581,7 +1581,7 @@ def forward( x = self.out_proj(x) if aux_loss_scale: - x = with_loss(x, self.cosine_loss(x.reshape(seq_len, batch_size * num_heads, value_head_dim).permute(1, 0, 2), + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, name=None) return x From 8a3ad7aeb6ee94ea9a7c9b2d697d6612993f8a28 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 19:01:14 +0800 Subject: [PATCH 0454/1191] Reduce aux_loss_scale to self_attn by factor of 10. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 536fed21b1..9133d8a02d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -584,13 +584,13 @@ def forward( src = src + self.feed_forward1(src) - src = src + self.self_attn1(src, attn_weights, aux_loss_scale=aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.self_attn1(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward2(src) - src = src + self.self_attn2(src, attn_weights, aux_loss_scale=aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.self_attn2(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) From c7df91d7933528020008d9cc93066b6fb00ed85b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 19:35:57 +0800 Subject: [PATCH 0455/1191] Increase initial max_similarity of self-attn from .25 to .5. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9133d8a02d..2b3a9b98c5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1537,7 +1537,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.25), (20000.0, 0.75), default=0.5)) + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) def forward( From f6270236dee4705ccb13c0141e6d0967e73aa7b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 19:38:37 +0800 Subject: [PATCH 0456/1191] Increase cosine_loss limits and make it per head. --- egs/librispeech/ASR/zipformer/zipformer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2b3a9b98c5..58bad12452 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1537,7 +1537,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.75), (20000.0, 0.9), default=0.5)) def forward( @@ -1569,7 +1569,13 @@ def forward( # todo: see whether there is benefit in overriding matmul x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) + # x: (num_heads, batch_size, seq_len, value_head_dim) + + if aux_loss_scale: + x = with_loss(x, self.cosine_loss(x.reshape(num_heads * batch_size, seq_len, value_head_dim), + src_key_padding_mask.repeat(num_heads, 1) if src_key_padding_mask is not None else None + ) * aux_loss_scale * 0.25, + name=None) x = ( x.permute(2, 1, 0, 3) @@ -1580,10 +1586,6 @@ def forward( # returned value is of shape (seq_len, batch_size, embed_dim), like the input. x = self.out_proj(x) - if aux_loss_scale: - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), - src_key_padding_mask) * aux_loss_scale * 0.25, - name=None) return x def streaming_forward( From 5778b4259119c919eb151b148faf203fed1f2a95 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 20:24:28 +0800 Subject: [PATCH 0457/1191] Multiply aux_loss_scale by 1/heads to correct for larger batch size than real batch size --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 58bad12452..5ceb347d03 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1574,7 +1574,7 @@ def forward( if aux_loss_scale: x = with_loss(x, self.cosine_loss(x.reshape(num_heads * batch_size, seq_len, value_head_dim), src_key_padding_mask.repeat(num_heads, 1) if src_key_padding_mask is not None else None - ) * aux_loss_scale * 0.25, + ) * aux_loss_scale * 0.25 * (1. / num_heads), name=None) x = ( From 1ba17fc365a7a113791d3ca4044c0b8d566aa4c4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Aug 2025 23:53:26 +0800 Subject: [PATCH 0458/1191] Fix decode script --- egs/librispeech/ASR/zapformer/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 85883ea113..504d1d94d2 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -452,7 +452,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens, _predict_loss, _cosine_loss = model.forward_encoder(feature, feature_lens) + encoder_out, encoder_out_lens, _predict_loss = model.forward_encoder(feature, feature_lens) hyps = [] From d29d4c5625b01d95a02694b3689ff35a90324bd5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 00:34:11 +0800 Subject: [PATCH 0459/1191] Merge deterministic_invertible1032conv --- egs/librispeech/ASR/zapformer/decode.py | 2 +- egs/librispeech/ASR/zapformer/model.py | 13 ++++--- egs/librispeech/ASR/zapformer/train.py | 18 ++++++--- egs/librispeech/ASR/zipformer/scaling.py | 13 +++++-- egs/librispeech/ASR/zipformer/zipformer.py | 45 ++++++++++++++-------- 5 files changed, 59 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 85883ea113..504d1d94d2 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -452,7 +452,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens, _predict_loss, _cosine_loss = model.forward_encoder(feature, feature_lens) + encoder_out, encoder_out_lens, _predict_loss = model.forward_encoder(feature, feature_lens) hyps = [] diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 4dd91dbe9d..7cf2a2781e 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -129,7 +129,7 @@ def __init__( def forward_encoder( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute encoder outputs. Args: @@ -159,12 +159,13 @@ def forward_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask) + encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask, + aux_loss_scale=aux_loss_scale) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss + return encoder_out, encoder_out_lens, predict_loss def forward_ctc( self, @@ -368,6 +369,7 @@ def forward( supervision_segments: Optional[torch.Tensor] = None, time_warp_factor: Optional[int] = 80, num_copies: int = 1, + aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -465,7 +467,8 @@ def forward( # Compute encoder outputs - encoder_out, encoder_out_lens, predict_loss, cosine_similarity_loss = self.forward_encoder(x, x_lens) + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -519,7 +522,7 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, encoder_out_lens) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss def forward_reconstruction_loss(self, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6b1b1838e4..137972d752 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -936,6 +936,7 @@ def compute_loss( batch: dict, is_training: bool, spec_augment: Optional[nn.Module] = None, + aux_loss_scale: float = 0.0, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -988,7 +989,7 @@ def compute_loss( spec_augment = None # disable spec-aug with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss, cosine_similarity_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -999,6 +1000,7 @@ def compute_loss( supervision_segments=supervision_segments, time_warp_factor=80, # for specaug num_copies=num_copies, + aux_loss_scale=aux_loss_scale, ) loss = 0.0 @@ -1025,8 +1027,6 @@ def warmup_schedule(scale, initial_factor): loss += reconstruction_loss_scale * reconstruction_loss - loss += cosine_similarity_loss - if num_copies > 1: loss += params.predict_loss_scale * predict_loss @@ -1055,7 +1055,6 @@ def warmup_schedule(scale, initial_factor): if num_copies > 1: info["predict_loss"] = predict_loss.detach().cpu().item() info["recon_loss"] = reconstruction_loss.detach().cpu().item() - info["cosine_similarity_loss"] = cosine_similarity_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -1150,6 +1149,12 @@ def train_one_epoch( saved_bad_model = False + def get_scaler_scale(): + if params.use_autocast and scaler._scale is not None: + return scaler._scale.item() + else: + return 1.0 + def save_bad_model(suffix: str = ""): if params.debug_interval > 0: optimizer.write_debug_info(summary_writer=tb_writer) @@ -1183,6 +1188,7 @@ def save_bad_model(suffix: str = ""): batch=batch, is_training=True, spec_augment=spec_augment, + aux_loss_scale=get_scaler_scale(), ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,7 +1244,7 @@ def save_bad_model(suffix: str = ""): ) if params.use_autocast: - cur_grad_scale = scaler._scale.item() + cur_grad_scale = get_scaler_scale() if cur_grad_scale < 0.01: if not saved_bad_model: @@ -1262,7 +1268,7 @@ def save_bad_model(suffix: str = ""): if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + cur_grad_scale = get_scaler_scale() logging.info( f"Epoch {params.cur_epoch}, " diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 59fb3afb52..e1c25ca767 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -966,9 +966,10 @@ def forward(self, x: Tensor, transpose: bool = False): class CosineSimilarityLoss(nn.Module): def __init__(self, - max_similarity: float): # e.g. 0.1 for max_similarity + max_similarity: FloatLike): # e.g. 0.1 for max_similarity super().__init__() self.max_similarity = max_similarity + self.name = None def forward(self, x: Tensor, @@ -980,7 +981,7 @@ def forward(self, mask: if supplied, mask of shape (batch_size, seq_len); True means masked positions. - Returns excess similarity as a sum over frames. + Returns excess similarity as a sum over frames, this should be treated as a loss. """ eps = 1.0e-10 x_norm = ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() @@ -999,8 +1000,12 @@ def forward(self, x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) similarity = (x * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also - excess_similarity = (similarity - self.max_similarity).relu() - return excess_similarity + excess_similarity = (similarity.sum(dim=1) - seq_len * float(self.max_similarity)).relu() + + if random.random() < 0.001: + logging.info("CosineSimilarityLoss: {self.name}, limit={float(self.max_similarity}, excess-similarity={excess_similarity.mean() / seq_len}") + + return excess_similarity.sum() # sum over batch dim. class ChunkCausalDepthwiseConv1d(torch.nn.Module): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 21e6c3d831..bffc02f0ca 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -44,6 +44,7 @@ penalize_abs_values_gt, PredictLoss, softmax, + with_loss, ) from torch import Tensor, nn @@ -218,6 +219,7 @@ def forward( x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, specaug_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ Args: @@ -232,6 +234,10 @@ def forward( specaug_mask: The mask that shows which frames were masked with specaug, of shape (batch_size, seq_len); True means masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) Returns: Return a tuple containing 4 tensors: - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) @@ -239,7 +245,6 @@ def forward( of frames in `embeddings` before padding. - predict_loss, a cross-prediction loss of randomized codebooks, relying on the CR-CTC structure of the batch. - - cosine_similarity_loss, a loss that encourages embedding vectors to be independent. """ chunk_size, left_context_chunks = self.get_chunk_info() orig_seq_len = x.shape[0] @@ -258,14 +263,15 @@ def forward( src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) specaug_mask = pad_mask(specaug_mask, x.shape[0]) + num_stacks = len(self.downsampling_factor) + predict_loss = 0.0 - cosine_similarity_loss = 0.0 for i, module in enumerate(self.encoders): ds = self.downsampling_factor[i] x = downsample_by(x, ds) T = x.shape[0] - x, this_pred_loss, this_cosine_similarity_loss = module( + x, this_pred_loss = module( x, chunk_size=chunk_size, src_key_padding_mask=( @@ -282,10 +288,10 @@ def forward( if attn_mask is None else attn_mask[::ds, ::ds] ), + aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) ) x = upsample_by(x, ds) - predict_loss += this_pred_loss * (ds / self.output_downsampling_factor) - cosine_similarity_loss += this_cosine_similarity_loss * (ds / self.output_downsampling_factor) + predict_loss += this_pred_loss * (ds / (self.output_downsampling_factor * num_stacks)) assert self.output_downsampling_factor == 2, self.output_downsampling_factor @@ -300,8 +306,7 @@ def forward( warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - L = len(self.downsampling_factor) - return x, lengths, predict_loss / L, cosine_similarity_loss / L + return x, lengths, predict_loss def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -603,6 +608,7 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tensor: """ Pass the input through the encoder layer. @@ -615,6 +621,10 @@ def forward( True means masked position. May be None. src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) Returns: A tensor which has the same shape as src @@ -816,13 +826,6 @@ def __init__( #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(3.0), - prob=(1, 1), - grad_scale=0.025, - ) - self.predict_loss = PredictLoss(dim) self.cosine_similarity_loss = CosineSimilarityLoss(max_similarity=0.05) @@ -833,6 +836,7 @@ def forward( attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, specaug_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -855,6 +859,7 @@ def forward( src = self.proj(src) # project to layer dim. + num_layers = len(self.layers) src_orig = src for i, mod in enumerate(self.layers): @@ -864,12 +869,12 @@ def forward( chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale/num_layers, ) # randomize_factor can be viewed as a simple version of an # importance-sampling factor. src = self.residual(src_orig, src) - src = self.whiten(src) # the following takes care of passing through the "rejected" dimension. src = src_orig_fulldim + self.proj(src - src_orig, transpose=True) @@ -883,7 +888,15 @@ def forward( else: mask = None - return src, self.predict_loss(src, mask), self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask).sum() + + # we will apply cosine_similarity_loss during backprop without printing it + # the 0.25 is a heuristic factor specific to cosine similarity loss. + if aux_loss_scale: # if not None and not zero.. + src = with_loss(src, + self.cosine_similarity_loss(src.permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, + name=None) + + return src, self.predict_loss(src, mask) def streaming_forward( self, From 9c0d9c005586bcd92eba762d1823881a4ac122ec Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 00:35:18 +0800 Subject: [PATCH 0460/1191] remove init_proj_special (cosine-based initialization of projections --- egs/librispeech/ASR/zipformer/zipformer.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bffc02f0ca..4313022e5a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -179,8 +179,6 @@ def _to_tuple(x): pos_dim=pos_dim, ) - init_proj_special(input_dim, downsampling_factor[i], encoder.proj.weight) - encoders.append(encoder) self.encoders = nn.ModuleList(encoders) @@ -521,20 +519,6 @@ def get_dct_matrix(N): mat[0] *= (2 ** -0.5) return mat -def init_proj_special(input_dim: int, downsampling_factor: int, weight_out: Tensor): - # special initialization of projection weight with orthonormal rows, so that low-freq - (num_rows, num_cols) = weight_out.shape - assert num_cols == input_dim * downsampling_factor and num_rows <= num_cols - weight = torch.eye(num_cols) - d = downsampling_factor - n = input_dim - weight = weight.reshape(d, n, d, n) - dct = get_dct_matrix(d) - weight = torch.matmul(dct, weight.reshape(d, -1)) - weight = weight.reshape(d * n, d * n) - with torch.no_grad(): - weight_out[:] = weight[:num_rows, :] - class Zipformer2EncoderLayer(nn.Module): """ From b89474fd90801beb3f9ad7cb2d2dd22fcf1d0c6f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 17:34:51 +0800 Subject: [PATCH 0461/1191] Double attn_score_limit schedule, final now 40 not 20. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 5ceb347d03..a016884411 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1195,7 +1195,7 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. - self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) + self.attn_score_limit = ScheduledFloat((0.0, 10.0), (5000.0, 40.0)) self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) key_head_dim = query_head_dim From c2ee356fc951befa7dbe91326e9bd46b2e69c691 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 18:29:06 +0800 Subject: [PATCH 0462/1191] Decrease max_similarity max from .9 to .75. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a016884411..2d646d2944 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1537,7 +1537,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.75), (20000.0, 0.9), default=0.5)) + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) def forward( From 234339c7fb7c257a923171d4fdffaecfe0273a87 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 19:41:54 +0800 Subject: [PATCH 0463/1191] Move aux_loss_scale to output of SelfAttention and decrease final limit to .5 --- egs/librispeech/ASR/zipformer/zipformer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2d646d2944..8801d0e6fc 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1537,7 +1537,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.25), (20000.0, 0.5), default=0.5)) def forward( @@ -1571,12 +1571,6 @@ def forward( x = torch.matmul(attn_weights, x) # x: (num_heads, batch_size, seq_len, value_head_dim) - if aux_loss_scale: - x = with_loss(x, self.cosine_loss(x.reshape(num_heads * batch_size, seq_len, value_head_dim), - src_key_padding_mask.repeat(num_heads, 1) if src_key_padding_mask is not None else None - ) * aux_loss_scale * 0.25 * (1. / num_heads), - name=None) - x = ( x.permute(2, 1, 0, 3) .contiguous() @@ -1586,6 +1580,11 @@ def forward( # returned value is of shape (seq_len, batch_size, embed_dim), like the input. x = self.out_proj(x) + if aux_loss_scale: + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), + src_key_padding_mask) * aux_loss_scale * 0.25, + name=None) + return x def streaming_forward( From 844a074e2a08a1e9308b9b4be016b2535df4c55d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 19:45:12 +0800 Subject: [PATCH 0464/1191] Move aux_loss_scale to output of SelfAttention and decrease final limit to .5, and incorporate mem-eff changes from 1034. --- egs/librispeech/ASR/zipformer/scaling.py | 95 ++++++++++++++++------ egs/librispeech/ASR/zipformer/zipformer.py | 6 +- 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3b3127fcc1..b7a7ce8d64 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -848,6 +848,58 @@ def forward(self, x: Tensor): return ans +class CosineSimilarityLossFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x: Tensor, mask: Optional[Tensor], max_similarity: float, weight: float, name: str): + ctx.save_for_backward(x) + ctx.mask = mask # mask will have no grad so it should be OK to store this way + ctx.name = name + ctx.weight = weight + ctx.max_similarity = max_similarity + return torch.tensor(0.0, device=x.device, dtype=x.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + x, = ctx.saved_tensors + mask = ctx.mask # optional Tensor + name = ctx.name # str + weight = ctx.weight # float + max_similarity = ctx.max_similarity # float + + + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + + eps = 3.0e-08 # won't be zero in float16 + x_norm = x / ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() + (batch_size, seq_len, num_channels) = x.shape + _, permutation = torch.rand(batch_size, seq_len, device=x.device).sort(dim=1) + # permutation: (batch_size, seq_len) + arange = torch.arange(seq_len, device=x.device) + mask2 = (permutation == arange) + if mask is not None: + mask = torch.logical_or(mask, mask2) + else: + mask = mask2 + x_norm = x_norm * (~mask).unsqueeze(-1).to(x.dtype) + + x_permuted = torch.gather(x_norm, 1, permutation.unsqueeze(-1).expand(*x.shape)) + + similarity = (x_norm * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also + excess_similarity = (similarity.sum(dim=1) - seq_len * max_similarity).relu() + + if random.random() < 0.001: + logging.info(f"CosineSimilarityLoss: {name}, limit={max_similarity}, excess-similarity={excess_similarity.mean() / seq_len}") + + grad = (weight * ans_grad).expand(excess_similarity.numel()) + excess_similarity.backward(grad) + + return x.grad, None, None, None, None + + class CosineSimilarityLoss(nn.Module): def __init__(self, max_similarity: FloatLike): # e.g. 0.1 for max_similarity @@ -857,39 +909,30 @@ def __init__(self, def forward(self, x: Tensor, + loss_scale: float, mask: Optional[Tensor] = None) -> Tensor: """ Compute cosine-similarity loss that tries to keep distinct output vectors distinct. - x: Tensor of shape (batch_size, seq_len, num_channels) - mask: if supplied, mask of shape (batch_size, seq_len); + x: Tensor of shape (batch_size, seq_len, num_channels) + loss_scale: the scale with which the loss should be incorporated into the graph. + This should contain a factor of the grad_scale, if you are using GradScaler for + automatic mixed precision training (amp). + The loss will be summed over frames, and multiplied by this value. + mask: if supplied, mask of shape (batch_size, seq_len); True means masked positions. - Returns excess similarity as a sum over frames, this should be treated as a loss. + Returns: + returns a scaled scalar loss value "ret" which should be incorporated + into the backprop graph by doing: + z = with_loss(z, ret, None) + where z is any quantity that will be used in calculating the main loss. + Ret will always be numerically equal to zero in the forward pass but + may behave as if it were nonzero for backprop purposes. """ - eps = 1.0e-10 - x_norm = ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() - x = x / x_norm - (batch_size, seq_len, num_channels) = x.shape - _, permutation = torch.rand(batch_size, seq_len, device=x.device).sort(dim=1) - # permutation: (batch_size, seq_len) - arange = torch.arange(seq_len, device=x.device) - mask2 = (permutation == arange) - if mask is not None: - mask = torch.logical_or(mask, mask2) - else: - mask = mask2 - x = x * (~mask).unsqueeze(-1).to(x.dtype) - - x_permuted = torch.gather(x, 1, permutation.unsqueeze(-1).expand(*x.shape)) - - similarity = (x * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also - excess_similarity = (similarity.sum(dim=1) - seq_len * float(self.max_similarity)).relu() - - if random.random() < 0.001: - logging.info(f"CosineSimilarityLoss: {self.name}, limit={float(self.max_similarity)}, excess-similarity={excess_similarity.mean() / seq_len}") - - return excess_similarity.sum() # sum over batch dim. + return CosineSimilarityLossFunction.apply(x, mask, + float(self.max_similarity), + loss_scale, self.name) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8801d0e6fc..43f09b13f6 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -833,7 +833,7 @@ def forward( # the 0.25 is a heuristic factor specific to cosine similarity loss. if aux_loss_scale: # if not None and not zero.. src = with_loss(src, - self.cosine_loss(src.permute(1, 0, 2), src_key_padding_mask) * aux_loss_scale * 0.25, + self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale * 0.25, src_key_padding_mask), name=None) return src, self.predict_loss(src, mask) @@ -1582,8 +1582,8 @@ def forward( if aux_loss_scale: x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), - src_key_padding_mask) * aux_loss_scale * 0.25, - name=None) + aux_loss_scale * 0.25, + mask=src_key_padding_mask), None) return x From 09851d5e80c8d6e54c849b1608b1b6cb38872dc1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 22:07:24 +0800 Subject: [PATCH 0465/1191] Increase similarity limit of SelfAttention to 0.75; restore attn_score_limit to old values --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 43f09b13f6..f1dff0f36e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1537,7 +1537,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.25), (20000.0, 0.5), default=0.5)) + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) def forward( From 942f88717819328a7f82f9ac7d0879cee6346e0e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 22:11:48 +0800 Subject: [PATCH 0466/1191] revert attn_score_limit to old value. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f1dff0f36e..d2bdde3dc4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1195,7 +1195,7 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. - self.attn_score_limit = ScheduledFloat((0.0, 10.0), (5000.0, 40.0)) + self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) key_head_dim = query_head_dim From ef688af68a7ed4931f123bcb43af98a86e903c9b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 27 Aug 2025 22:22:06 +0800 Subject: [PATCH 0467/1191] Increase initial value of attn_score_limit from 5 to 10 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d2bdde3dc4..0fa115219b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1195,7 +1195,7 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. - self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) + self.attn_score_limit = ScheduledFloat((0.0, 10.0), (5000.0, 20.0)) self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) key_head_dim = query_head_dim From 57f22860e8feca0aa05215ea788b7581705fa4ed Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 28 Aug 2025 17:43:52 +0800 Subject: [PATCH 0468/1191] Remove out_whiten from feedforward modules. --- egs/librispeech/ASR/zipformer/zipformer.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0fa115219b..be6ec0feca 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1666,17 +1666,10 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): initial_scale=0.5, ) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) def forward(self, x: Tensor): x = self.in_proj(x) x = self.out_proj(x) - x = self.out_whiten(x) return x From 2cdd8295140e6f751fc60235d072c646e681c31a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 28 Aug 2025 18:02:53 +0800 Subject: [PATCH 0469/1191] Remove whitening of ConvolutionMOdule --- egs/librispeech/ASR/zipformer/zipformer.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index be6ec0feca..56f66ac2b4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1878,13 +1878,6 @@ def __init__( ) ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, @@ -1942,7 +1935,6 @@ def forward( x = x.permute(2, 0, 1) # (time, batch, channels) - x = self.whiten(x) # (time, batch, channels) x = self.out_proj(x) # (time, batch, channels) return x From b0636ff95066d5b737244fd72006542bab825b56 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 28 Aug 2025 20:02:58 +0800 Subject: [PATCH 0470/1191] Introduce cosine loss on feedforward modules (max=0.2) to replacd previously removed whiten module. --- egs/librispeech/ASR/zipformer/zipformer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index be6ec0feca..2749a1593a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -582,19 +582,19 @@ def forward( key_padding_mask=src_key_padding_mask, ) - src = src + self.feed_forward1(src) + src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.self_attn1(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) - src = src + self.feed_forward2(src) + src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.self_attn2(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) - src = src + self.feed_forward3(src) + src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = self.residual(src_orig, src) @@ -1666,10 +1666,13 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): initial_scale=0.5, ) + self.cosine_loss = CosineSimilarityLoss(max_similarity=0.2) + - def forward(self, x: Tensor): + def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: x = self.in_proj(x) x = self.out_proj(x) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale * 0.25, src_key_padding_mask), None) return x From 9302e6fd5fc60275fdc2ef41bb4e31a74e549207 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 28 Aug 2025 22:20:21 +0800 Subject: [PATCH 0471/1191] Introduce cosine_loss (max=0.2) on output of conv modules --- egs/librispeech/ASR/zipformer/zipformer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 172029696b..47020a1171 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -586,13 +586,13 @@ def forward( src = src + self.self_attn1(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.self_attn2(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -1888,12 +1888,15 @@ def __init__( dropout_p=0.0, initial_scale=0.05, ) + self.cosine_loss = CosineSimilarityLoss(max_similarity=0.2) + def forward( self, x: Tensor, src_key_padding_mask: Optional[Tensor] = None, chunk_size: int = -1, + aux_loss_scale: float = 0.0, ) -> Tensor: """Compute convolution module. @@ -1940,6 +1943,9 @@ def forward( x = self.out_proj(x) # (time, batch, channels) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), + None) + return x def streaming_forward( From ae015d98eac215b33a38e2178dd7be99d2ce9f4b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 29 Aug 2025 11:07:48 +0800 Subject: [PATCH 0472/1191] Make max_similarity a rank-dependent formula --- egs/librispeech/ASR/zipformer/zipformer.py | 32 +++++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 47020a1171..6353c35619 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -472,6 +472,31 @@ def get_init_states( return states +def get_max_similarity(rank: int, power: float): + """ + This returns a value for the "max_similarity" argument of CosineSimilarityLoss. + the max_similarity is an upper limit we impose on the mean value of (x_i . x_j) + if i != j are two different sequence-position indexes and x_i and x_j are + activation vectors normalized to have unit length. + + rank: the dimension of the space, usually this is the num_channels, but if + we have just up-projected from a bottleneck, it would be the bottleneck + dimension. + power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean + we enforce the vector dimensions to be completely independent like Gaussian noise + (don't do this); if we set power=0.0 it would be equivalent to not having + the CosineSimilarityLoss at all. + + The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal + variable. If x consists of independent Gaussian noise of dimension D, with + variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" + would be close to a no-op for large D), then (x_i . x_j) would be distributed as + a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) + would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value + between this and 1, as a kind of heuristic limit on this max_similarity. + """ + return (0.7978845608 / (rank ** 0.5)) ** power + def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) @@ -741,7 +766,6 @@ class Zipformer2Encoder(nn.Module): """ - def __init__( self, encoder_layer: nn.Module, @@ -766,7 +790,7 @@ def __init__( self.copy_bypass = Identity() self.predict_loss = PredictLoss(dim) - self.cosine_loss = CosineSimilarityLoss(max_similarity=0.05) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) def forward( self, @@ -1666,7 +1690,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): initial_scale=0.5, ) - self.cosine_loss = CosineSimilarityLoss(max_similarity=0.2) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.5)) def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: @@ -1888,7 +1912,7 @@ def __init__( dropout_p=0.0, initial_scale=0.05, ) - self.cosine_loss = CosineSimilarityLoss(max_similarity=0.2) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.5)) def forward( From c2f8c5d53ed245b11f7d5d46bcd030c62ce6faeb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 12:04:03 +0800 Subject: [PATCH 0473/1191] Introduce ProductLoss to try to keep the big dimensions going through the non-residual term. --- egs/librispeech/ASR/zipformer/scaling.py | 91 ++++++++++++++++++++++ egs/librispeech/ASR/zipformer/zipformer.py | 51 +++++++++--- 2 files changed, 130 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ba189a9086..6758a2ed3d 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1051,6 +1051,97 @@ def forward(self, loss_scale, self.name) +class MinProductLossFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x: Tensor, y: Tensor, mask: Optional[Tensor], min_product: float, weight: float, name: str): + ctx.save_for_backward(x, y) + ctx.mask = mask # mask will have no grad so it should be OK to store this way + ctx.name = name + ctx.weight = weight + ctx.min_product = min_product + return torch.tensor(0.0, device=x.device, dtype=x.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + x, y = ctx.saved_tensors + mask = ctx.mask # optional Tensor + name = ctx.name # str + weight = ctx.weight # float + min_product = ctx.min_product # float + + + with torch.enable_grad(): + x, y = x.detach(), y.detach() + x.requires_grad = True + y.requires_grad = True + + eps = 3.0e-08 # won't be zero in float16 + x_norm = x / ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() + y_norm = y / ((y ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() + (batch_size, seq_len, num_channels) = x.shape + + + + product = x_norm * y_norm + product = product.sum(dim=-1) + if mask is not None: + inv_mask = (~mask).to(x.dtype) + product = product * inv_mask + + if mask is not None: + product_deficit = (inv_mask.sum(dim=1) * min_product - product.sum(dim=1)).relu() + else: + product_deficit = (seq_len * min_product - product.sum(dim=1)).relu() + + if random.random() < 0.0005: + logging.info(f"MinProductLoss: {name}, limit={min_product}, product-deficit={product_deficit.mean() / seq_len}") + + grad = (weight * ans_grad).expand(product_deficit.numel()) + product_deficit.backward(grad) + + return x.grad, y.grad, None, None, None, None + +class MinProductLoss(nn.Module): + def __init__(self, + min_product: FloatLike): # e.g. 0.5 for min_product + super().__init__() + self.min_product = min_product + self.name = None + + def forward(self, + x: Tensor, + y: Tensor, + loss_scale: float, + mask: Optional[Tensor] = None) -> Tensor: + """ + Compute loss that tries to keep two embeddings in similar directions, used to + make sure that the bulk of the embedding goes through one branch. + + x: Tensor of shape (batch_size, seq_len, num_channels) + y: Tensor of shape (batch_size, seq_len, num_channels) + loss_scale: the scale with which the loss should be incorporated into the graph. + This should contain a factor of the grad_scale, if you are using GradScaler for + automatic mixed precision training (amp). + The loss will be summed over frames, and multiplied by this value. + mask: if supplied, mask of shape (batch_size, seq_len); + True means masked positions that will be ignored. + + Returns: + returns a scaled scalar loss value "ret" which should be incorporated + into the backprop graph by doing: + z = with_loss(z, ret, None) + where z is any quantity that will be used in calculating the main loss. + Ret will always be numerically equal to zero in the forward pass but + may behave as if it were nonzero for backprop purposes. + """ + return MinProductLossFunction.apply(x, y, mask, + float(self.min_product), + loss_scale, self.name) + + + class ChunkCausalDepthwiseConv1d(torch.nn.Module): """ Behaves like a depthwise 1d convolution, except that it is causal in diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fdcfe0d455..d5f6a27d99 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -35,6 +35,7 @@ ExpNorm, ChunkCausalDepthwiseConv1d, CosineSimilarityLoss, + MinProductLoss, Dropout2, FloatLike, ScheduledFloat, @@ -831,12 +832,16 @@ def __init__( ) self.num_layers = num_layers - self.residual = ResidualModule(encoder_layer.embed_dim) + self.residual_scale = nn.Parameter(0.5 * torch.zeros(encoder_layer.embed_dim)) #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() self.predict_loss = PredictLoss(dim) + + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) + self.min_product_loss = MinProductLoss(0.5) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) def forward( @@ -884,10 +889,10 @@ def forward( # randomize_factor can be viewed as a simple version of an # importance-sampling factor. - src = self.residual(src_orig, src) - - # the following takes care of passing through the "rejected" dimension. - src = src_orig_fulldim + self.proj(src - src_orig, transpose=True) + src = self.add_residual(src_orig_fulldim, src_orig, src, aux_loss_scale, src_key_padding_mask) + # The above is equivalent to: + # src = src_orig_fulldim + self.proj((src - src_orig) * self.residual_scale, transpose=True) + # .. but with extra losses. if src_key_padding_mask is not None and specaug_mask is not None: mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) @@ -898,15 +903,37 @@ def forward( else: mask = None + return src, self.predict_loss(src, mask) - # we will apply cosine_loss during backprop without printing it - # the 0.25 is a heuristic factor specific to cosine similarity loss. - if aux_loss_scale: # if not None and not zero.. - src = with_loss(src, - self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale * 0.25, src_key_padding_mask), - name=None) - return src, self.predict_loss(src, mask) + def add_residual( + self, + src_orig_fulldim, + src_orig, + src, + aux_loss_scale: float, + src_key_padding_mask: Optional[Tensor]): + residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0, training=self.training) + offset = (src - src_orig) * residual_scale + if aux_loss_scale: + offset = with_loss(offset, + self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale * 0.25, src_key_padding_mask), + None) + + offset = self.proj(offset, transpose=True) + tot = src_orig_fulldim + offset + + if aux_loss_scale: + tot_permuted = tot.permute(1, 0, 2) + tot = with_loss(tot, + self.cosine_loss(tot_permuted, + aux_loss_scale * 0.25, src_key_padding_mask) + + self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask), + None) + + return tot + def streaming_forward( self, From 2f6aa5bfde544940c08dfe6a5ba55f4a39f30d20 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 31 Aug 2025 16:12:36 +0800 Subject: [PATCH 0474/1191] Change to torch.compile to try to pad shapes. --- egs/librispeech/ASR/zipformer/scaling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6758a2ed3d..30529cb353 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1709,9 +1709,10 @@ def forward(self, x: Tensor) -> Tensor: + def torch_compile(fn, *args, **kwargs): if hasattr(torch, 'compile'): - fn = torch.compile(fn, *args, **kwargs) + fn = torch.compile(fn, *args, **kwargs, dynamic=True, options={"shape_padding": True, "force_shape_pad": True}) return fn def swashl(x: Tensor) -> Tensor: From 995229b9befb3d5cbea9f04504ec99e83913cd5f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 12:26:23 +0800 Subject: [PATCH 0475/1191] Increase prob of printing MinProductLoss. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 30529cb353..e0391ece24 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1095,7 +1095,7 @@ def backward(ctx, ans_grad): else: product_deficit = (seq_len * min_product - product.sum(dim=1)).relu() - if random.random() < 0.0005: + if random.random() < 0.002: logging.info(f"MinProductLoss: {name}, limit={min_product}, product-deficit={product_deficit.mean() / seq_len}") grad = (weight * ans_grad).expand(product_deficit.numel()) From 15206fa707ddbf755e55d5ebf7d11708c26e3986 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 13:07:47 +0800 Subject: [PATCH 0476/1191] Reduce scale on min product loss by factor of 4 and decrease the minimum from .5 to .25. Was causing violation of cosine loss. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e0391ece24..1d677a3e7f 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1095,7 +1095,7 @@ def backward(ctx, ans_grad): else: product_deficit = (seq_len * min_product - product.sum(dim=1)).relu() - if random.random() < 0.002: + if random.random() < 0.005: logging.info(f"MinProductLoss: {name}, limit={min_product}, product-deficit={product_deficit.mean() / seq_len}") grad = (weight * ans_grad).expand(product_deficit.numel()) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d5f6a27d99..778f486b7b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -840,7 +840,7 @@ def __init__( self.predict_loss = PredictLoss(dim) self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) - self.min_product_loss = MinProductLoss(0.5) + self.min_product_loss = MinProductLoss(0.25) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) @@ -929,7 +929,7 @@ def add_residual( self.cosine_loss(tot_permuted, aux_loss_scale * 0.25, src_key_padding_mask) + self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask), + aux_loss_scale * 0.25, src_key_padding_mask), None) return tot From 60af23e9172a63109ffda2e24006da2dfd28eff5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 13:08:43 +0800 Subject: [PATCH 0477/1191] remove offset_cosine_loss --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 778f486b7b..522b835775 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -839,7 +839,6 @@ def __init__( self.predict_loss = PredictLoss(dim) - self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) self.min_product_loss = MinProductLoss(0.25) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) @@ -915,10 +914,6 @@ def add_residual( src_key_padding_mask: Optional[Tensor]): residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0, training=self.training) offset = (src - src_orig) * residual_scale - if aux_loss_scale: - offset = with_loss(offset, - self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale * 0.25, src_key_padding_mask), - None) offset = self.proj(offset, transpose=True) tot = src_orig_fulldim + offset From 9b964b2ec3d1ed6edf9032066722ca22d9394321 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 13:43:59 +0800 Subject: [PATCH 0478/1191] Reduce min_product_loss scale by factor of 10. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 522b835775..472d654c86 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -924,7 +924,7 @@ def add_residual( self.cosine_loss(tot_permuted, aux_loss_scale * 0.25, src_key_padding_mask) + self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), - aux_loss_scale * 0.25, src_key_padding_mask), + aux_loss_scale * 0.025, src_key_padding_mask), None) return tot From 7efaed03f46905b6ca38103f3ab109c3a128a7bf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 13:47:57 +0800 Subject: [PATCH 0479/1191] Increase min_product from .25 to .5 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 472d654c86..5fcca27d7f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -839,7 +839,7 @@ def __init__( self.predict_loss = PredictLoss(dim) - self.min_product_loss = MinProductLoss(0.25) + self.min_product_loss = MinProductLoss(0.5) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From 9a81dd68c9d33b1512b4262bab72be80f8767b36 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 15:01:22 +0800 Subject: [PATCH 0480/1191] Reduce min_product_loss scale by another factor of 10. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 5fcca27d7f..15fa874dd7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -924,7 +924,7 @@ def add_residual( self.cosine_loss(tot_permuted, aux_loss_scale * 0.25, src_key_padding_mask) + self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), - aux_loss_scale * 0.025, src_key_padding_mask), + aux_loss_scale * 0.0025, src_key_padding_mask), None) return tot From d577ba105ee7bf0e72139027d565d9e950e2d665 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 15:16:51 +0800 Subject: [PATCH 0481/1191] Fix initialization bug. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 15fa874dd7..bffb05aa46 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -832,7 +832,7 @@ def __init__( ) self.num_layers = num_layers - self.residual_scale = nn.Parameter(0.5 * torch.zeros(encoder_layer.embed_dim)) + self.residual_scale = nn.Parameter(0.5 * torch.ones(encoder_layer.embed_dim)) #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() From 65ee3ab0d6c02bbb070ed08d9a1d08974f413552 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 16:55:27 +0800 Subject: [PATCH 0482/1191] Introduce out_proj on downsampling layers. --- egs/librispeech/ASR/zipformer/zipformer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bffb05aa46..6e05ba83e2 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -178,6 +178,7 @@ def _to_tuple(x): num_encoder_layers[i], dim=downsampling_factor[i]*input_dim, pos_dim=pos_dim, + out_proj=(downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], ) encoders.append(encoder) @@ -816,6 +817,7 @@ def __init__( num_layers: int, dim: int, pos_dim: int, + out_proj: bool, ) -> None: super().__init__() @@ -843,6 +845,13 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear + # module. + if out_proj: + self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False, + penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + self.out_proj.lr_scale = 0.75 + def forward( self, src: Tensor, @@ -902,6 +911,9 @@ def forward( else: mask = None + if hasattr(self, 'out_proj'): + src = self.out_proj(src) + return src, self.predict_loss(src, mask) From bfcdebf95cc6e964f56076dede37122038098192 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 17:02:34 +0800 Subject: [PATCH 0483/1191] restore warmup_batches from 1500 to default of 500. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 137972d752..2f706a036f 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1421,7 +1421,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), warmup_batches=1500) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From efae008c236f083d3341ae5a5d640466d67055e7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 17:44:10 +0800 Subject: [PATCH 0484/1191] Adjust min_product threshold based on a formula. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6e05ba83e2..56b7b5d856 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -841,7 +841,11 @@ def __init__( self.predict_loss = PredictLoss(dim) - self.min_product_loss = MinProductLoss(0.5) + ratio = 1.2 # require the "passed-through" dims be larger by a factor of 1.2 tyan the bypassed dims. + d_yes = encoder_layer.embed_dim + d_no = dim - encoder_layer.embed_dim + min_product = (d_yes * ratio) / (d_yes * ratio + d_no) + self.min_product_loss = MinProductLoss(min_product) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From 7d9e0d57653d10c7a2a7880199db78dec6e6719c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 18:34:27 +0800 Subject: [PATCH 0485/1191] Reduce the proportion of the variance that the non-residual term must take. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 56b7b5d856..d9f5109d45 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -841,10 +841,9 @@ def __init__( self.predict_loss = PredictLoss(dim) - ratio = 1.2 # require the "passed-through" dims be larger by a factor of 1.2 tyan the bypassed dims. d_yes = encoder_layer.embed_dim d_no = dim - encoder_layer.embed_dim - min_product = (d_yes * ratio) / (d_yes * ratio + d_no) + min_product = (d_yes * 0.75) / (d_yes + d_no) self.min_product_loss = MinProductLoss(min_product) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From 9e0962b79439da643def8f5430834e1c1aa6e12f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Sep 2025 20:20:32 +0800 Subject: [PATCH 0486/1191] Make cosine loss scales 5 times smaller, with some refactoring. --- egs/librispeech/ASR/zapformer/train.py | 10 +++++++++- egs/librispeech/ASR/zipformer/zipformer.py | 8 ++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2f706a036f..d41e37644c 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -514,6 +514,14 @@ def get_parser(): "with this parameter before adding to the final loss.", ) + parser.add_argument( + "--aux-loss-scale", + type=float, + default=0.05, + help="Scale on auxiliary losses that are defined in the model, such " + "as cosine loss.", + ) + parser.add_argument( "--ctc-loss-scale", type=float, @@ -1188,7 +1196,7 @@ def save_bad_model(suffix: str = ""): batch=batch, is_training=True, spec_augment=spec_augment, - aux_loss_scale=get_scaler_scale(), + aux_loss_scale=get_scaler_scale() * params.aux_loss_scale, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d9f5109d45..7b0138b4fc 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -937,9 +937,9 @@ def add_residual( tot_permuted = tot.permute(1, 0, 2) tot = with_loss(tot, self.cosine_loss(tot_permuted, - aux_loss_scale * 0.25, src_key_padding_mask) + + aux_loss_scale, src_key_padding_mask) + self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), - aux_loss_scale * 0.0025, src_key_padding_mask), + aux_loss_scale * 0.05, src_key_padding_mask), None) return tot @@ -1693,7 +1693,7 @@ def forward( if aux_loss_scale: x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), - aux_loss_scale * 0.25, + aux_loss_scale, mask=src_key_padding_mask), None) return x @@ -1783,7 +1783,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: x = self.in_proj(x) x = self.out_proj(x) - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale * 0.25, src_key_padding_mask), None) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), None) return x From 140ac0d941bdddc41e19a1a382b9e7d2980c0dc3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 16:17:42 +0800 Subject: [PATCH 0487/1191] Implement max attention qk and qp products using two separate MaxProductLoss() --- egs/librispeech/ASR/zipformer/scaling.py | 121 ++++++++++++++++++++- egs/librispeech/ASR/zipformer/zipformer.py | 48 ++++---- 2 files changed, 145 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1d677a3e7f..73a092e685 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1028,7 +1028,9 @@ def forward(self, loss_scale: float, mask: Optional[Tensor] = None) -> Tensor: """ - Compute cosine-similarity loss that tries to keep distinct output vectors distinct. + Compute cosine-similarity loss that tries to make sure distinct output vectors + have inner products with small magnitude (after normalization), i.e. the cosine + of the angle between should be close to zero. x: Tensor of shape (batch_size, seq_len, num_channels) loss_scale: the scale with which the loss should be incorporated into the graph. @@ -1141,6 +1143,123 @@ def forward(self, loss_scale, self.name) +class MinProductLoss(nn.Module): + def __init__(self, + min_product: FloatLike): # e.g. 0.5 for min_product + super().__init__() + self.min_product = min_product + self.name = None + + def forward(self, + x: Tensor, + y: Tensor, + loss_scale: float, + mask: Optional[Tensor] = None) -> Tensor: + """ + Compute loss that tries to keep two embeddings in similar directions, used to + make sure that the bulk of the embedding goes through one branch. + + x: Tensor of shape (batch_size, seq_len, num_channels) + y: Tensor of shape (batch_size, seq_len, num_channels) + loss_scale: the scale with which the loss should be incorporated into the graph. + This should contain a factor of the grad_scale, if you are using GradScaler for + automatic mixed precision training (amp). + The loss will be summed over frames, and multiplied by this value. + mask: if supplied, mask of shape (batch_size, seq_len); + True means masked positions that will be ignored. + + Returns: + returns a scaled scalar loss value "ret" which should be incorporated + into the backprop graph by doing: + z = with_loss(z, ret, None) + where z is any quantity that will be used in calculating the main loss. + Ret will always be numerically equal to zero in the forward pass but + may behave as if it were nonzero for backprop purposes. + """ + return MinProductLossFunction.apply(x, y, mask, + float(self.min_product), + loss_scale, self.name) + + +class MaxProductLossFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x: Tensor, y: Tensor, max_product: float, weight: float, name: str): + ctx.save_for_backward(x, y) + ctx.name = name + ctx.weight = weight + ctx.max_product = max_product + return torch.tensor(0.0, device=x.device, dtype=x.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + x, y = ctx.saved_tensors + name = ctx.name # str + weight = ctx.weight # float + max_product = ctx.max_product # float + + with torch.enable_grad(): + x, y = x.detach(), y.detach() + x.requires_grad = True + y.requires_grad = True + + (batch_size, seq_len, num_channels) = x.shape + seq_len2 = y.shape[1] + indexes = torch.randint(0, seq_len2, (batch_size, seq_len, 1), device=x.device) + + y = torch.gather(y, 1, indexes.expand(*x.shape)) + + product = (x * y).sum(dim=-1).abs() + + excess_product = (product.sum(dim=1) - seq_len * max_product).relu() + + if random.random() < 0.001: + logging.info(f"MaxProduct: {name}, limit={max_product}, excess-product={excess_product.mean() / seq_len}") + + grad = (weight * ans_grad).expand(excess_product.numel()) + excess_product.backward(grad) + + return x.grad, y.grad, None, None, None + +class MaxProductLoss(nn.Module): + def __init__(self, + max_product: FloatLike): # e.g. 20.0 for max_product + super().__init__() + self.max_product = max_product + self.name = None + + def forward(self, + x: Tensor, + y: Tensor, + loss_scale: float) -> Tensor: + """ + Compute loss that limits the average dot product (without normalization) + between x, and (y, but randomly permuted on the sequence dimension). It is + intended for limiting dot-products of queries and keys. + + x: Tensor of shape (batch_size, seq_len, num_channels) + y: Tensor of shape (batch_size, seq_len2, num_channels) [seq_len2 does not have to equal seq_len]. + loss_scale: the scale with which the loss should be incorporated into the graph. + This should contain a factor of the grad_scale, if you are using GradScaler for + automatic mixed precision training (amp). We divide this by max_product, + so that it penalizes relative, not absolute, violations of the max-product + rule. + The loss will be summed over frames of x, and multiplied by this value. + + Returns: + returns a scaled scalar loss value "ret" which should be incorporated + into the backprop graph by doing: + z = with_loss(z, ret, None) + where z is any quantity that will be used in calculating the main loss. + Ret will always be numerically equal to zero in the forward pass but + may behave as if it were nonzero for backprop purposes. + """ + max_product = float(self.max_product) + return MaxProductLossFunction.apply(x, y, max_product, + loss_scale / max_product, + self.name) + class ChunkCausalDepthwiseConv1d(torch.nn.Module): """ diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7b0138b4fc..24788d6975 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -36,6 +36,7 @@ ChunkCausalDepthwiseConv1d, CosineSimilarityLoss, MinProductLoss, + MaxProductLoss, Dropout2, FloatLike, ScheduledFloat, @@ -650,6 +651,7 @@ def forward( pos_emb=pos_emb, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, + aux_loss_scale=0.1 * aux_loss_scale, ) src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -1306,9 +1308,6 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. - self.attn_score_limit = ScheduledFloat((0.0, 10.0), (5000.0, 20.0)) - self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) - key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads @@ -1338,12 +1337,17 @@ def __init__( self.copy_pos_query = Identity() self.copy_query = Identity() + self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 2.5), (5000.0, 10.0), default=10.0)) + self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.5), (5000.0, 2.0), default=2.0)) + + def forward( self, x: Tensor, pos_emb: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tensor: r""" Args: @@ -1391,6 +1395,14 @@ def forward( p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + if self.training: + k = with_loss(k, + self.qk_max_product(q.reshape(num_heads * batch_size, seq_len, query_head_dim), + q.permute(0, 1, 3, 2).reshape(num_heads * batch_size, seq_len, query_head_dim), + aux_loss_scale / num_heads), + None) + + attn_scores = torch.matmul(q, k) if True: @@ -1400,7 +1412,15 @@ def forward( pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( 2, 0, 3, 1 ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + # pos shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2) + + if self.training: + pe = pos_emb.expand(num_heads, batch_size, pos_head_dim, seq_len2) + pe = pe.reshape(num_heads * batch_size, pos_head_dim, seq_len2).permute(0, 2, 1) + p = with_loss(p, + self.pos_max_product(p.reshape(num_heads * batch_size, seq_len, pos_head_dim), pe, + aux_loss_scale / num_heads), + None) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.] @@ -1429,26 +1449,8 @@ def forward( storage_offset=pos_scores.stride(3) * (seq_len - 1), ) - attn_scores = attn_scores + pos_scores - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < float(self.attn_score_penalty_prob): - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt( - attn_scores, limit=float(self.attn_score_limit), penalty=1.0e-04, name=self.name - ) + attn_scores = attn_scores + pos_scores assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) From 3d0b448fcbde90d2db4842628725461886103c06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 16:54:35 +0800 Subject: [PATCH 0488/1191] Replace whiten_keys in attention with cosine loss (power=0.7). --- egs/librispeech/ASR/zipformer/zipformer.py | 26 +++++++++++++++------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7b0138b4fc..272763bd96 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -650,6 +650,7 @@ def forward( pos_emb=pos_emb, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, + aux_loss_scale=0.1 * aux_loss_scale, ) src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -1322,12 +1323,9 @@ def __init__( bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025, - ) + + self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.7)) + # linear transformation for positional encoding. self.linear_pos = ScaledLinear( @@ -1337,6 +1335,7 @@ def __init__( # the following are for diagnostics only, see --print-diagnostics option self.copy_pos_query = Identity() self.copy_query = Identity() + self.copy_key = Identity() def forward( self, @@ -1344,6 +1343,7 @@ def forward( pos_emb: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, ) -> Tensor: r""" Args: @@ -1379,13 +1379,21 @@ def forward( ) q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(k) # does nothing in the forward pass. [this may not really be needed due to the orthogonality constraint.] - p = self.copy_pos_query(p) # for diagnostics only, does nothing. + k = self.copy_key(k) + p = self.copy_pos_query(p) q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + if aux_loss_scale: + k = with_loss(k, + self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), + aux_loss_scale / num_heads, + key_padding_mask.repeat_interleave(num_heads, 1) if key_padding_mask is not None else None), + None) + + # time1 refers to target, time2 refers to source. q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) @@ -2131,12 +2139,14 @@ def _test_zipformer_main(causal: bool = False): f, lengths, predict_loss = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, ) f.sum().backward() c.eval() f = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, ) f # to remove flake8 warnings From 3d273ccbb1ea95c304440f1fd0688f589e20a86b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 17:09:51 +0800 Subject: [PATCH 0489/1191] Fix to usage of repeat_interleave --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 272763bd96..d11587db40 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1390,7 +1390,7 @@ def forward( k = with_loss(k, self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), aux_loss_scale / num_heads, - key_padding_mask.repeat_interleave(num_heads, 1) if key_padding_mask is not None else None), + key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None), None) From e73e8d5a1234118c65898f4a8d6bde6af0473bec Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 18:09:01 +0800 Subject: [PATCH 0490/1191] Decrease limits of max_product average values. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 24788d6975..f5a65fdae3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1337,8 +1337,8 @@ def __init__( self.copy_pos_query = Identity() self.copy_query = Identity() - self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 2.5), (5000.0, 10.0), default=10.0)) - self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.5), (5000.0, 2.0), default=2.0)) + self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 1.0), (20000.0, 4.0), default=10.0)) + self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.2), (20000.0, 1.0), default=2.0)) def forward( From affcf2451d60eb9d6e99d4202426d3e4bd1746ad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 19:08:17 +0800 Subject: [PATCH 0491/1191] Decrease power in key_cosine_loss from .75 to .5. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d11587db40..c547908dcf 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1324,7 +1324,7 @@ def __init__( ) - self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.7)) + self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.5)) # linear transformation for positional encoding. From 6c41e8daa40db872510d8c6a7c7cf5a0b6f28167 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 19:23:19 +0800 Subject: [PATCH 0492/1191] Fix bug in backprop --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 73a092e685..a6713f494e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1208,9 +1208,9 @@ def backward(ctx, ans_grad): seq_len2 = y.shape[1] indexes = torch.randint(0, seq_len2, (batch_size, seq_len, 1), device=x.device) - y = torch.gather(y, 1, indexes.expand(*x.shape)) + y_rand = torch.gather(y, 1, indexes.expand(*x.shape)) - product = (x * y).sum(dim=-1).abs() + product = (x * y_rand).sum(dim=-1).abs() excess_product = (product.sum(dim=1) - seq_len * max_product).relu() From 4b8c7a3ff9614e0d0ec6feca6040bfd4a8cd63ad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 20:08:29 +0800 Subject: [PATCH 0493/1191] Increase max_product for position scores. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f5a65fdae3..77e43280be 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1338,7 +1338,7 @@ def __init__( self.copy_query = Identity() self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 1.0), (20000.0, 4.0), default=10.0)) - self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.2), (20000.0, 1.0), default=2.0)) + self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 1.0), (20000.0, 4.0), default=2.0)) def forward( From 4c1a34913861a1a4124b6756d130480f58169fc3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Sep 2025 20:31:30 +0800 Subject: [PATCH 0494/1191] Change schedules for max products in self-attn --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 77e43280be..365f8f7762 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1337,8 +1337,8 @@ def __init__( self.copy_pos_query = Identity() self.copy_query = Identity() - self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 1.0), (20000.0, 4.0), default=10.0)) - self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 1.0), (20000.0, 4.0), default=2.0)) + self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 1.0), (10000.0, 4.0), default=10.0)) + self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.5), (10000.0, 2.0), default=2.0)) def forward( From b459ef88d721b764d6f9f014c1863df73861dad8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Sep 2025 00:06:24 +0800 Subject: [PATCH 0495/1191] Tune schedules and final maximum for pos_max_product, qk_max_product. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 365f8f7762..7d1143d0a6 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1337,8 +1337,8 @@ def __init__( self.copy_pos_query = Identity() self.copy_query = Identity() - self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 1.0), (10000.0, 4.0), default=10.0)) - self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.5), (10000.0, 2.0), default=2.0)) + self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.6), (20000.0, 6.0), default=5.0)) + self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.4), (20000.0, 4.0), default=5.0)) def forward( From 91fbfed7c48133081b1903014c735082d199e0af Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 2 Sep 2025 14:25:13 +0800 Subject: [PATCH 0496/1191] Reduce number of stacks by one, fewer total layers. --- egs/librispeech/ASR/zapformer/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d41e37644c..0d63bf5812 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,14 +185,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="3,5,7,5,4,7,5", + default="3,5,6,6,6,5", help="Number of zipformer encoder layers per stack, comma separated.", ) parser.add_argument( "--downsampling-factor", type=str, - default="1,2,4,8,8,4,2", + default="1,2,4,8,4,2", help="Downsampling factor for each stack of encoder layers.", ) @@ -213,21 +213,21 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--feedforward-multiple", type=str, - default="3,3,3,3,3,3,3", + default="3,3,3,3,3,3", help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", ) parser.add_argument( "--num-heads", type=str, - default="4,4,4,8,8,4,4", + default="4,4,4,8,4,4", help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) parser.add_argument( "--encoder-multiple", type=str, - default="4,6,9,12,12,9,6", + default="4,6,9,12,9,6", help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="31,31,15,15,15,15,31", + default="31,31,15,15,15,31", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", ) From 995c0631bcc42508a9ba0977daa82366315feeae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 2 Sep 2025 14:54:43 +0800 Subject: [PATCH 0497/1191] Reduce min_ratio for setting min_product from 0.75 to 0.5, allowing more bypass. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7b0138b4fc..bc984faa4b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -843,7 +843,7 @@ def __init__( d_yes = encoder_layer.embed_dim d_no = dim - encoder_layer.embed_dim - min_product = (d_yes * 0.75) / (d_yes + d_no) + min_product = (d_yes * 0.5) / (d_yes + d_no) self.min_product_loss = MinProductLoss(min_product) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From ca28c6e7c7fa8786396645bb585dd89bc237657e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Sep 2025 11:13:44 +0800 Subject: [PATCH 0498/1191] Separate weight computation for self_attn1 and self_attn2 but have both be done upfront. --- egs/librispeech/ASR/zipformer/zipformer.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bc984faa4b..e13b50fa44 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -588,7 +588,15 @@ def __init__( embed_dim, ) - self.self_attn_weights = RelPositionMultiheadAttentionWeights( + self.self_attn_weights1 = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + self.self_attn_weights2 = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, num_heads=num_heads, @@ -645,7 +653,13 @@ def forward( src_orig = src # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( + attn_weights1 = self.self_attn_weights1( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + attn_weights2 = self.self_attn_weights2( src, pos_emb=pos_emb, attn_mask=attn_mask, @@ -654,13 +668,13 @@ def forward( src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.self_attn1(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.self_attn2(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) From 75d4f198cfa5c17d5c58c404e68629a220cc6a4a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Sep 2025 17:32:26 +0800 Subject: [PATCH 0499/1191] Move predict_loss out of zipformer stacks after the encoder. --- egs/librispeech/ASR/zapformer/model.py | 27 ++++++++++++-- egs/librispeech/ASR/zipformer/zipformer.py | 42 ++++------------------ 2 files changed, 31 insertions(+), 38 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 7cf2a2781e..821f5aab5e 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from torch import Tensor from encoder_interface import EncoderInterface -from scaling import ScaledLinear, convert_num_channels +from scaling import ScaledLinear, convert_num_channels, PredictLoss from icefall.utils import add_sos, make_pad_mask, time_warp @@ -86,6 +86,8 @@ def __init__( self.encoder_embed = encoder_embed self.encoder = encoder + self.predict_loss = PredictLoss(encoder_dim) + self.use_transducer = use_transducer if use_transducer: # Modules for Transducer head @@ -159,14 +161,33 @@ def forward_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens, predict_loss = self.encoder(x, x_lens, src_key_padding_mask, specaug_mask=specaug_mask, - aux_loss_scale=aux_loss_scale) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, + aux_loss_scale=aux_loss_scale) + + predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask, specaug_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + return encoder_out, encoder_out_lens, predict_loss + + def compute_predict_loss(self, + encoder_out: Tensor, + src_key_padding_mask: Optional[Tensor], + specaug_mask: Optional[Tensor]) -> Tensor: + if src_key_padding_mask is not None and specaug_mask is not None: + mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) + elif src_key_padding_mask is not None: + mask = src_key_padding_mask.t().logical_not() + elif specaug_mask is not None: + mask = specaug_mask.t().logical_not() + else: + mask = None + return self.predict_loss(encoder_out, mask) + + def forward_ctc( self, encoder_out: torch.Tensor, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bc984faa4b..7518062499 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -43,7 +43,6 @@ convert_num_channels, limit_param_value, penalize_abs_values_gt, - PredictLoss, softmax, with_loss, ) @@ -218,7 +217,6 @@ def forward( x: Tensor, x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, - specaug_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ @@ -231,9 +229,6 @@ def forward( src_key_padding_mask: The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. - specaug_mask: - The mask that shows which frames were masked with specaug, of shape (batch_size, seq_len); - True means masked position. May be None. aux_loss_scale: If supplied, auxiliary losses such as CosineSimilarityLoss will be applied with this scale on the loss (note, these aux losses are @@ -261,19 +256,14 @@ def forward( attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) - specaug_mask = pad_mask(specaug_mask, x.shape[0]) num_stacks = len(self.downsampling_factor) - num_stacks = len(self.downsampling_factor) - - predict_loss = 0.0 - for i, module in enumerate(self.encoders): ds = self.downsampling_factor[i] x = downsample_by(x, ds) T = x.shape[0] - x, this_pred_loss = module( + x = module( x, chunk_size=chunk_size, src_key_padding_mask=( @@ -281,11 +271,6 @@ def forward( if src_key_padding_mask is None else src_key_padding_mask[..., ::ds] ), - specaug_mask=( - None - if specaug_mask is None - else specaug_mask[..., ::ds] - ), attn_mask=(None if attn_mask is None else attn_mask[::ds, ::ds] @@ -293,7 +278,6 @@ def forward( aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) ) x = upsample_by(x, ds) - predict_loss += this_pred_loss * (ds / (self.output_downsampling_factor * num_stacks)) assert self.output_downsampling_factor == 2, self.output_downsampling_factor @@ -308,7 +292,7 @@ def forward( warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - return x, lengths, predict_loss + return x, lengths def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -839,8 +823,6 @@ def __init__( #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() - self.predict_loss = PredictLoss(dim) - d_yes = encoder_layer.embed_dim d_no = dim - encoder_layer.embed_dim min_product = (d_yes * 0.5) / (d_yes + d_no) @@ -861,7 +843,6 @@ def forward( chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - specaug_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -905,19 +886,10 @@ def forward( # src = src_orig_fulldim + self.proj((src - src_orig) * self.residual_scale, transpose=True) # .. but with extra losses. - if src_key_padding_mask is not None and specaug_mask is not None: - mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) - elif src_key_padding_mask is not None: - mask = src_key_padding_mask.t().logical_not() - elif specaug_mask is not None: - mask = specaug_mask.t().logical_not() - else: - mask = None - if hasattr(self, 'out_proj'): src = self.out_proj(src) - return src, self.predict_loss(src, mask) + return src def add_residual( @@ -2125,20 +2097,20 @@ def _test_zipformer_main(causal: bool = False): left_context_frames=(64,), ) - batch_size = 6 # make it even, as PredictLoss requires even batch size. + batch_size = 6 seq_len = 21 # Just make sure the forward pass runs. - f, lengths, predict_loss = c( + f, lengths = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) f.sum().backward() c.eval() - f = c( + x_ = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) - f # to remove flake8 warnings + x_ # to remove flake8 warnings if __name__ == "__main__": From 8d76df3f1acb28b8d5223ea58991546d3ad81cbc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Sep 2025 17:48:05 +0800 Subject: [PATCH 0500/1191] Bug fix RE subsampling for predict_loss. --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 821f5aab5e..96e13210b3 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -164,7 +164,7 @@ def forward_encoder( encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, aux_loss_scale=aux_loss_scale) - predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask, specaug_mask) + predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) From 6cf925ecfe444b7e77960d32db710bf4dc14d9e1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Sep 2025 20:53:56 +0800 Subject: [PATCH 0501/1191] Bug fix to qk product limitation, use k not q. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7d1143d0a6..a85e389c9e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1398,7 +1398,7 @@ def forward( if self.training: k = with_loss(k, self.qk_max_product(q.reshape(num_heads * batch_size, seq_len, query_head_dim), - q.permute(0, 1, 3, 2).reshape(num_heads * batch_size, seq_len, query_head_dim), + k.permute(0, 1, 3, 2).reshape(num_heads * batch_size, seq_len, query_head_dim), aux_loss_scale / num_heads), None) From 8b3311780f8ad23f60e8fb2b5231b386150a2eda Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 5 Sep 2025 13:15:49 +0800 Subject: [PATCH 0502/1191] Revert 1069 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bc984faa4b..7b0138b4fc 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -843,7 +843,7 @@ def __init__( d_yes = encoder_layer.embed_dim d_no = dim - encoder_layer.embed_dim - min_product = (d_yes * 0.5) / (d_yes + d_no) + min_product = (d_yes * 0.75) / (d_yes + d_no) self.min_product_loss = MinProductLoss(min_product) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From 6a74a170affbaca42c90b6905d95f59e3d5882ce Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 5 Sep 2025 15:35:28 +0800 Subject: [PATCH 0503/1191] Introduce another self-attn module, before first feedforward. --- egs/librispeech/ASR/zipformer/zipformer.py | 31 +++++++--------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e13b50fa44..811bebc123 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -588,24 +588,16 @@ def __init__( embed_dim, ) - self.self_attn_weights1 = RelPositionMultiheadAttentionWeights( + self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - dropout=0.0, - ) - self.self_attn_weights2 = RelPositionMultiheadAttentionWeights( - embed_dim, - pos_dim=pos_dim, - num_heads=num_heads, + num_heads=3 * num_heads, query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, dropout=0.0, ) - self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] + self.self_attn1, self.self_attn2, self.self_attn3 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(3) ] feedforward_dim = embed_dim * feedforward_multiple self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) @@ -653,28 +645,25 @@ def forward( src_orig = src # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights1 = self.self_attn_weights1( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - attn_weights2 = self.self_attn_weights2( + attn_weights1, attn_weights2, attn_weights3 = self.self_attn_weights( src, pos_emb=pos_emb, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, - ) + ).chunk(3, dim=0) - src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.self_attn3(src, attn_weights3, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) From edda0c0e296ccb9091c331e13d8437158de4cc31 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 5 Sep 2025 15:38:46 +0800 Subject: [PATCH 0504/1191] Introduce another self-attn module, before first feedforward. Have the heads partially overlap to save memory. --- egs/librispeech/ASR/zipformer/zipformer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 811bebc123..2789f6b4e0 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -591,7 +591,7 @@ def __init__( self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, - num_heads=3 * num_heads, + num_heads=2 * num_heads, query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, dropout=0.0, @@ -645,13 +645,16 @@ def forward( src_orig = src # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights1, attn_weights2, attn_weights3 = self.self_attn_weights( + attn_weights = self.self_attn_weights( src, pos_emb=pos_emb, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, - ).chunk(3, dim=0) - + ) + num_heads = attn_weights.shape[0] // 2 # num heads per self_attn module + attn_weights1 = attn_weights[:num_heads] + attn_weights2 = attn_weights[num_heads//2:-num_heads//2] + attn_weights3 = attn_weights[num_heads:] src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) From 18d3d51d9bde20906ce91903be9a954ccd598200 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 5 Sep 2025 17:34:19 +0800 Subject: [PATCH 0505/1191] Have variable numbers of conv modules in zipformer layers. --- egs/librispeech/ASR/zipformer/zipformer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2789f6b4e0..bf77b1df60 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -168,6 +168,7 @@ def _to_tuple(x): feedforward_multiple=feedforward_multiple[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], + num_conv_modules=(2 if downsampling_factor[i] <= 2 else (1 if downsampling_factor[i] <= 4 else 0)), causal=causal, ) @@ -575,6 +576,7 @@ def __init__( feedforward_multiple: int, dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, + num_conv_modules: int = 2, causal: bool = False, randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.75)), ) -> None: @@ -606,8 +608,10 @@ def __init__( self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) - self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - for _ in range(2) ] + if num_conv_modules >= 2: + self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + if num_conv_modules >= 1: + self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) self.scale_limiter = ScaleLimiter(max_var=2.0) @@ -662,13 +666,15 @@ def forward( src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + if hasattr(self, 'conv_module1'): + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.self_attn3(src, attn_weights3, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + if hasattr(self, 'conv_module2'): + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) From 908f8d31fade23d773819677d6216e4a33858d6c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 5 Sep 2025 13:15:49 +0800 Subject: [PATCH 0506/1191] Revert 1069 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bf77b1df60..03b3c8dda3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -855,7 +855,7 @@ def __init__( d_yes = encoder_layer.embed_dim d_no = dim - encoder_layer.embed_dim - min_product = (d_yes * 0.5) / (d_yes + d_no) + min_product = (d_yes * 0.75) / (d_yes + d_no) self.min_product_loss = MinProductLoss(min_product) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From ad93b8fd14d2714918aa23a8904b2df855462b8a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 6 Sep 2025 13:34:37 +0800 Subject: [PATCH 0507/1191] Change formula to min_product = (d_yes * 0.75) / (d_yes + 0.75 * d_no), introducing factor of .75 on d_no, will affect most-bypassed layers the most. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 03b3c8dda3..763a49aa33 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -855,7 +855,7 @@ def __init__( d_yes = encoder_layer.embed_dim d_no = dim - encoder_layer.embed_dim - min_product = (d_yes * 0.75) / (d_yes + d_no) + min_product = (d_yes * 0.75) / (d_yes + 0.75 * d_no) self.min_product_loss = MinProductLoss(min_product) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From 6287a3f22a9f06772b73f5a99fc6272d9f28820d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 6 Sep 2025 13:35:36 +0800 Subject: [PATCH 0508/1191] Remove penalty_scale schedule on out_proj of zipformer encoders (compare branch 1074). --- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 763a49aa33..0eb155c753 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -863,8 +863,7 @@ def __init__( # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear # module. if out_proj: - self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False, - penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) self.out_proj.lr_scale = 0.75 def forward( From 6877693f7a2b1f2ecf89452243088d23b87762a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 6 Sep 2025 16:11:12 +0800 Subject: [PATCH 0509/1191] Replace Whiten in Conv2dSubsmpling with CosineSimilarityLoss. --- egs/librispeech/ASR/zapformer/model.py | 2 +- egs/librispeech/ASR/zapformer/train.py | 3 +- egs/librispeech/ASR/zipformer/scaling.py | 26 +++++++++ egs/librispeech/ASR/zipformer/subsampling.py | 57 +++++++++++++------- 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 96e13210b3..c40054747f 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -150,7 +150,7 @@ def forward_encoder( # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) - x, x_lens = self.encoder_embed(x, x_lens) + x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 0d63bf5812..94eadc41ed 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -353,7 +353,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--use-ctc", type=str2bool, - default=False, + default=True, help="If True, use CTC head.", ) @@ -717,7 +717,6 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, out_channels=lookup(params, "embed_dim"), - dropout=0.0, ) return encoder_embed diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a6713f494e..d574212981 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1016,6 +1016,32 @@ def forward(self, x: Tensor, transpose: bool = False): weight = weight.t() return torch.nn.functional.linear(x, weight, self.bias) +def get_max_similarity(rank: int, power: float): + """ + For use when initializing CosineSimilarityLoss, this returns a value for + the "max_similarity" argument. + max_similarity is an upper limit we impose on the mean value of (x_i . x_j), + where i != j are two different sequence-position indexes and x_i and x_j are + activation vectors normalized to have unit length. + + rank: the dimension of the space, usually this is the num_channels, but if + we have just up-projected from a bottleneck, it would be the bottleneck + dimension. + power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean + we enforce the vector dimensions to be completely independent like Gaussian noise + (don't do this); if we set power=0.0 it would be equivalent to not having + the CosineSimilarityLoss at all. + + The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal + variable. If x consists of independent Gaussian noise of dimension D, with + variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" + would be close to a no-op for large D), then (x_i . x_j) would be distributed as + a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) + would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value + between this and 1, as a kind of heuristic limit on this max_similarity. + """ + return (0.7978845608 / (rank ** 0.5)) ** power + class CosineSimilarityLoss(nn.Module): def __init__(self, max_similarity: FloatLike): # e.g. 0.1 for max_similarity diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 67ea511c4e..3e2d8b502c 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -24,17 +24,44 @@ ScaleLimiter, ScaledLinear, ExpNorm, - Dropout3, FloatLike, ScaledConv2d, ScaleGrad, ScheduledFloat, SwashL, SwashR, - Whiten, + CosineSimilarityLoss, + with_loss, ) from torch import Tensor, nn +# TEMP: put this here, eventually we should import from scaling.py +def get_max_similarity(rank: int, power: float): + """ + For use when initializing CosineSimilarityLoss, this returns a value for + the "max_similarity" argument. + max_similarity is an upper limit we impose on the mean value of (x_i . x_j), + where i != j are two different sequence-position indexes and x_i and x_j are + activation vectors normalized to have unit length. + + rank: the dimension of the space, usually this is the num_channels, but if + we have just up-projected from a bottleneck, it would be the bottleneck + dimension. + power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean + we enforce the vector dimensions to be completely independent like Gaussian noise + (don't do this); if we set power=0.0 it would be equivalent to not having + the CosineSimilarityLoss at all. + + The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal + variable. If x consists of independent Gaussian noise of dimension D, with + variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" + would be close to a no-op for large D), then (x_i . x_j) would be distributed as + a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) + would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value + between this and 1, as a kind of heuristic limit on this max_similarity. + """ + return (0.7978845608 / (rank ** 0.5)) ** power + class ConvNeXt(nn.Module): """ @@ -154,7 +181,6 @@ def __init__( layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, - dropout: FloatLike = 0.1, ) -> None: """ Args: @@ -219,23 +245,12 @@ def __init__( self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, initial_scale=4.0) - # use a larger than normal grad_scale on this whitening module; there is - # only one such module, so there is not a concern about adding together - # many copies of this extra gradient term. - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), - prob=(0.025, 0.25), - grad_scale=0.02, - ) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.85)) - # max_log_eps=0.0 is to prevent both eps and the output of self.out from - # getting large, there is an unnecessary degree of freedom. self.out_norm = ExpNorm(out_channels) - self.dropout = Dropout3(dropout, shared_dim=1) def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Subsample x. @@ -262,16 +277,18 @@ def forward( x = self.out(x) # Now x is of shape (N, (T-7)//2, odim) - x = self.out_whiten(x) - x = self.out_norm(x) - x = self.dropout(x) - if torch.jit.is_scripting() or torch.jit.is_tracing(): x_lens = (x_lens - 7) // 2 else: with warnings.catch_warnings(): warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 + + key_padding_mask = torch.arange(0, x.shape[0], device=x.device) >= x_lens.unsqueeze(-1) + # key_padding_mask: (N, (T-7)//2) + x = with_loss(x, self.cosine_loss(x, aux_loss_scale, key_padding_mask), None) + x = self.out_norm(x) + assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) return x, x_lens From 474be25809f8d5c8e4526b92859ddc86b9f01e00 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 6 Sep 2025 16:24:15 +0800 Subject: [PATCH 0510/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 ++ egs/librispeech/ASR/zipformer/subsampling.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d574212981..1c3a1fde5d 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -853,6 +853,8 @@ class CosineSimilarityLossFunction(torch.autograd.Function): @custom_fwd def forward(ctx, x: Tensor, mask: Optional[Tensor], max_similarity: float, weight: float, name: str): ctx.save_for_backward(x) + if mask is not None: + assert mask.shape == x.shape[:2], (list(mask.shape), list(x.shape)) ctx.mask = mask # mask will have no grad so it should be OK to store this way ctx.name = name ctx.weight = weight diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 3e2d8b502c..0bcea60dfc 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -284,7 +284,7 @@ def forward( warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 - key_padding_mask = torch.arange(0, x.shape[0], device=x.device) >= x_lens.unsqueeze(-1) + key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) x = with_loss(x, self.cosine_loss(x, aux_loss_scale, key_padding_mask), None) x = self.out_norm(x) From 9499a180a91ec67c751654545904e9f13ef62f47 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 7 Sep 2025 14:48:29 +0800 Subject: [PATCH 0511/1191] Use rank-norm instead of mean and variance norm for PredictLoss. --- egs/librispeech/ASR/zipformer/scaling.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1c3a1fde5d..2f2a0a9542 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -537,28 +537,29 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, name: str, mask: Optional[Tensor]) -> Tensor: - # caution: now require input to be (seq, batch, channel) + # caution: require input to be (seq, batch, channel) batch_size = x.shape[1] if batch_size % 2 != 0: assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." return torch.tensor(0.0, device=x.device) - def mean_and_variance_norm(x): - mean = x.mean(dim=(0,1), keepdim=True) - x = x - 1.5 * mean # over-normalization. - eps = 1.0e-08 - stddev = ((x ** 2).mean(dim=(0, 1)) + eps).sqrt() - x = x / stddev - return x - + def rank_norm(x): + values, indexes = x.sort(dim=0) # sort on seq dim + # norm_rank: same shape as x + norm_rank = torch.linspace(-1., 1., x.shape[0], device=x.device, dtype=x.dtype) + norm_rank = norm_rank.reshape(-1, 1, 1) + norm_rank = norm_rank.repeat(1, x.shape[1], x.shape[2]) + x_norm = torch.empty_like(x) + x_norm.scatter_(dim=0, index=indexes, src=norm_rank) + return x_norm with torch.no_grad(): # get the indexes. project, then mean-and-variance-norm, then # take mx. x_proj = torch.matmul(x, proj_weight.t()) with torch.amp.autocast('cuda', enabled=False): - x_proj = mean_and_variance_norm(x_proj.to(torch.float)) + x_proj = rank_norm(x_proj.to(torch.float)) indexes = torch.max(x_proj, dim=-1)[1] From 10099da909ba9ea7cd5aab735933eecdf02be45a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 7 Sep 2025 16:36:47 +0800 Subject: [PATCH 0512/1191] Change prediction loss to be an average-square loss of predicting ranks, not a lobprob. --- egs/librispeech/ASR/zipformer/scaling.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2f2a0a9542..4d3565ae6c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -560,24 +560,24 @@ def rank_norm(x): x_proj = torch.matmul(x, proj_weight.t()) with torch.amp.autocast('cuda', enabled=False): x_proj = rank_norm(x_proj.to(torch.float)) - indexes = torch.max(x_proj, dim=-1)[1] - indexes = torch.roll(indexes, batch_size // 2, 1) + x_proj = torch.roll(x_proj, batch_size // 2, 1) x_pred = predictor(x) - logprobs = x_pred.log_softmax(dim=-1) - loss = -torch.gather(logprobs, dim=-1, index=indexes.unsqueeze(-1)) + + loss = ((x_pred - x_proj) ** 2).mean(dim=-1) if random.random() < 0.002: logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") if mask is not None: mask = mask.to(x.dtype) - # we also swap the mask over the two copies of the data; the mask goes with the thing that + # note, this mask is True for *non*-masked positions. + # we swap the mask over the two copies of the data; the mask goes with the thing that # is predicted, not the thing we predict it from.. the idea being that we don't want to ask # the model to predict masked portions of the time sequence. mask = torch.roll(mask, batch_size // 2, 1) - loss = loss * mask.unsqueeze(-1) + loss = loss * mask return loss.sum() # we reduce with sum in what we return. From 7cef79d5682f206ad1fe5e9b381aa3c167e036e1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 7 Sep 2025 17:45:29 +0800 Subject: [PATCH 0513/1191] Reduce power for Conv2dSubsampling's CosineSimilarityLoss from 0.75 to 0.65. --- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 0bcea60dfc..69df15072b 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -245,7 +245,7 @@ def __init__( self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, initial_scale=4.0) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.85)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) self.out_norm = ExpNorm(out_channels) From 2a81bd9da813bdfd6dab218c9622c70a1599fb25 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 7 Sep 2025 21:28:58 +0800 Subject: [PATCH 0514/1191] Gaussianize the ranks. --- egs/librispeech/ASR/zipformer/scaling.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 4d3565ae6c..df5791b035 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -544,10 +544,13 @@ def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." return torch.tensor(0.0, device=x.device) - def rank_norm(x): + def gauss_norm(x): + # normalize by gaussianizing on each dimension values, indexes = x.sort(dim=0) # sort on seq dim # norm_rank: same shape as x - norm_rank = torch.linspace(-1., 1., x.shape[0], device=x.device, dtype=x.dtype) + N = max(2, x.shape[0]) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[0], device=x.device, dtype=torch.float) + norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data norm_rank = norm_rank.reshape(-1, 1, 1) norm_rank = norm_rank.repeat(1, x.shape[1], x.shape[2]) x_norm = torch.empty_like(x) @@ -559,7 +562,7 @@ def rank_norm(x): # take mx. x_proj = torch.matmul(x, proj_weight.t()) with torch.amp.autocast('cuda', enabled=False): - x_proj = rank_norm(x_proj.to(torch.float)) + x_proj = gauss_norm(x_proj.to(torch.float)) x_proj = torch.roll(x_proj, batch_size // 2, 1) From cdeba3fc1c88d822feea59577b296d80505fb8f5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 7 Sep 2025 22:22:19 +0800 Subject: [PATCH 0515/1191] Remove penalty_scale schedule on out_proj of Zipformer2Encoder --- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 265af51259..58f194a109 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -841,8 +841,7 @@ def __init__( # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear # module. if out_proj: - self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False, - penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) self.out_proj.lr_scale = 0.75 def forward( From 58e1df441c1a4298548b5ca04c23f4020089b93e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 11:30:22 +0800 Subject: [PATCH 0516/1191] Change formula to min_product = (d_yes * 0.75) / (d_yes + 0.75 * d_no), new factor of .75 in den, like deterministic_invertible1095conv. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 58f194a109..b4b78ddceb 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -833,7 +833,7 @@ def __init__( d_yes = encoder_layer.embed_dim d_no = dim - encoder_layer.embed_dim - min_product = (d_yes * 0.75) / (d_yes + d_no) + min_product = (d_yes * 0.75) / (d_yes + 0.75 * d_no) self.min_product_loss = MinProductLoss(min_product) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) From 0cacb5f092f4da9361973b30687ea4c61a5559ba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 11:43:13 +0800 Subject: [PATCH 0517/1191] Introduce CosineLoss in zipformer encoder layers, before the residual, with power=0.75 for setting the max_similarity. --- egs/librispeech/ASR/zipformer/zipformer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b4b78ddceb..483cda1dff 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -568,6 +568,9 @@ def __init__( self.name = None # will be set from training loop self.randomize_scale = copy.deepcopy(randomize_scale) + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.75)) + # self.bypass implements layer skipping as well as learnable scale on a residual term; see its default values. self.residual = ResidualModule( embed_dim, @@ -658,6 +661,10 @@ def forward( src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = with_loss(src, + self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), + None) + src = self.residual(src_orig, src) src = self.scale_limiter(src) From 87e33ed0af0175dd9d4483bac64b362ba95dc3c8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 17:39:03 +0800 Subject: [PATCH 0518/1191] Gaussianize targets for reconstruction_loss and make the loss be squared difference. --- egs/librispeech/ASR/zapformer/model.py | 75 ++++++-------------------- 1 file changed, 17 insertions(+), 58 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index c40054747f..c83d04aaea 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -561,6 +561,21 @@ def forward_reconstruction_loss(self, batch_size = log_mels.shape[0] num_mels = log_mels.shape[2] + + def gauss_norm(x): + # normalize by gaussianizing on each dimension + values, indexes = x.sort(dim=1) # sort on seq dim + N = max(2, x.shape[1]) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[0], device=x.device, dtype=torch.float) + norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data + norm_rank = norm_rank.reshape(1, -1, 1) + norm_rank = norm_rank.repeat(x.shape[0], 1, x.shape[2]) + x_norm = torch.empty_like(x) + x_norm.scatter_(dim=0, index=indexes, src=norm_rank) + return x_norm + + log_mels = gauss_norm(log_mels) + pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) T_embed = pred_mels.shape[1] pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) @@ -586,11 +601,8 @@ def forward_reconstruction_loss(self, # this way of applying the padding mask is not really ideal in terms of normalization, # it will cause us to under-normalize a bit. diff = log_mels * pad_mask - pred_mels * pad_mask - # mean over sequence and mel-bin dims but not batch. - # this smooth_l1_loss_mod is intended to accomplish volume normalization at the - # sequence level, i.e. in case the differently-augmented signals have a difference in volume, - # which could happen due to musan augmentation. - loss = smooth_l1_loss_mod(diff, beta=1.0, norm_dims=(1, 2)) + + loss = (diff ** 2) # removing the masking logic since we now use the no-specaug reference sequence. ## masking. if it's different from the next item on both the frequency dim @@ -602,56 +614,3 @@ def forward_reconstruction_loss(self, loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. return loss - - - -def smooth_l1_loss_mod(diffs: Tensor, beta: float = 1.0, - norm_dims: Optional[Tuple[int]] = None): - """ - This is similar to : - loss = torch.nn.SmoothL1Loss(reduction='none', beta=beta) - loss(a, b) is similar to smooth_l1_loss_mod(a - b), - except that it does an optional normalization step that involves - subtracting a mean computed over 'norm_dims'. - """ - assert beta > 0 - def get_scale(diffs): - # torch.nn.SmoothL1Loss(reduction='none', beta=beta) is: - # l_n = 0.5 * (diff^2 / beta) if |diff| < beta - # else: |diff| - 0.5 / beta - diffs_abs = diffs.abs() - l2_loss = (0.5 / beta) * (diffs ** 2) - l1_loss = diffs.abs() - (0.5 * beta) - # 'scale' is a loss scale such that if we multiply l2_loss by it, - # we get the final loss. - scale = l1_loss.clamp(min=0.5 * beta) / l2_loss.clamp(min=0.5 * beta) - return scale.sqrt() - # ok, now we can treat the loss as (0.5 / beta) * diffs_scaled ** 2 - if norm_dims: - scale = get_scale(diffs) - offset = (scale * diffs).mean(dim=norm_dims, keepdim=True) / scale.mean(dim=norm_dims, keepdim=True) - diffs = diffs - offset - - loss = (0.5 / beta) * ((diffs * get_scale(diffs)) ** 2) - return loss - - - -def _test_smooth_l1_loss_mod(): - a = torch.randn(4, 50) - b = torch.randn(4, 50) + 10. * torch.randn(4, 1) - - beta = 2.0 - loss = torch.nn.SmoothL1Loss(reduction='none', beta=beta) - loss1 = loss(a, b) - loss2 = smooth_l1_loss_mod(a - b, beta=beta) - #print(f"loss1={loss1}, loss2={loss2}") - assert torch.allclose(loss1, loss2, atol=0.001) - - loss2_norm = smooth_l1_loss_mod(a - b, beta=beta, norm_dims=(1,)) - print(f"loss2-mean={loss2.mean()}, loss2_norm-mean={loss2_norm.mean()}") - assert loss2_norm.mean() <= loss2.mean() - - -if __name__ == '__main__': - _test_smooth_l1_loss_mod() From 4121c944ca7ce727afd4c4379da4899fec4f4b96 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 18:35:09 +0800 Subject: [PATCH 0519/1191] Bug fix --- egs/librispeech/ASR/zapformer/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index c83d04aaea..e4abf65a7c 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -566,12 +566,12 @@ def gauss_norm(x): # normalize by gaussianizing on each dimension values, indexes = x.sort(dim=1) # sort on seq dim N = max(2, x.shape[1]) - norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[0], device=x.device, dtype=torch.float) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[1], device=x.device, dtype=torch.float) norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data norm_rank = norm_rank.reshape(1, -1, 1) norm_rank = norm_rank.repeat(x.shape[0], 1, x.shape[2]) x_norm = torch.empty_like(x) - x_norm.scatter_(dim=0, index=indexes, src=norm_rank) + x_norm.scatter_(dim=1, index=indexes, src=norm_rank) return x_norm log_mels = gauss_norm(log_mels) From 6a74c8dc7823a0e7ae1bd92103a588a407fa83c7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 22:56:48 +0800 Subject: [PATCH 0520/1191] Remove cosine loss from layers and put it instead after each encoder. --- egs/librispeech/ASR/zipformer/zipformer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ece5779b04..14fe2dfaba 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -571,8 +571,6 @@ def __init__( self.randomize_scale = copy.deepcopy(randomize_scale) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.75)) - # self.bypass implements layer skipping as well as learnable scale on a residual term; see its default values. self.residual = ResidualModule( embed_dim, @@ -667,10 +665,6 @@ def forward( src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = with_loss(src, - self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), - None) - src = self.residual(src_orig, src) src = self.scale_limiter(src) @@ -849,6 +843,7 @@ def __init__( min_product = (d_yes * 0.75) / (d_yes + 0.75 * d_no) self.min_product_loss = MinProductLoss(min_product) + self.encoder_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear @@ -930,6 +925,8 @@ def add_residual( tot = with_loss(tot, self.cosine_loss(tot_permuted, aux_loss_scale, src_key_padding_mask) + + self.encoder_cosine_loss(src, + aux_loss_scale, src_key_padding_mask) + self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), aux_loss_scale * 0.05, src_key_padding_mask), None) From d990396b8f170e13c68cdfa1fa3c089e88bd6803 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 23:29:43 +0800 Subject: [PATCH 0521/1191] Implement stochastic depth with projections. --- egs/librispeech/ASR/zapformer/model.py | 43 ++++++++++++------- egs/librispeech/ASR/zipformer/zipformer.py | 49 +++++++++++++++------- 2 files changed, 62 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index e4abf65a7c..082762ab94 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -132,7 +132,7 @@ def __init__( def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute encoder outputs. Args: x: @@ -146,11 +146,16 @@ def forward_encoder( Encoder output, of shape (N, T, C). encoder_out_lens: Encoder output lengths, of shape (N,). + encoder_out_sd: + Stochastic-depth version of encoder output + predict_loss: + Cross-sequence prediction loss value """ # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) - x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) + x, x_lens, x_sd = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) + # x_sd is stochastic-depth version of x. # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") @@ -161,16 +166,17 @@ def forward_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, - aux_loss_scale=aux_loss_scale) + encoder_out, encoder_out_lens, encoder_out_sd = self.encoder(x, x_lens, src_key_padding_mask, + aux_loss_scale=aux_loss_scale) - predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) + predict_loss = (0.9 * self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) + + 0.1 * self.compute_predict_loss(encoder_out_sd, src_key_padding_mask[:, ::2], specaug_mask[:, ::2])) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens, predict_loss + return encoder_out, encoder_out_lens, x_sd, predict_loss def compute_predict_loss(self, @@ -488,8 +494,8 @@ def forward( # Compute encoder outputs - encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, - aux_loss_scale=aux_loss_scale) + encoder_out, encoder_out_lens, encoder_out_sd, predict_loss = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -512,20 +518,23 @@ def forward( if self.use_ctc: targets = y.values if not self.training: - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, + ctc_loss, ctc_loss_sd = [ self.forward_ctc( + encoder_out=e, encoder_out_lens=encoder_out_lens, targets=targets, target_lengths=y_lens, - ) + ) for e in [encoder_out, encoder_out_sd] ] + ctc_loss = 0.9 * ctc_loss + 0.1 * ctc_loss_sd cr_loss = torch.empty(0) else: - ctc_loss, cr_loss = self.forward_cr_ctc( - encoder_out=encoder_out, + ret, ret_sd = [ self.forward_cr_ctc( + encoder_out=e, encoder_out_lens=encoder_out_lens, targets=targets, target_lengths=y_lens, - ) + ) for e in [encoder_out, encoder_out_sd] ] + ctc_loss = 0.9 * ret[0] + 0.1 * ret_sd[0] + cr_loss = 0.9 * ret[1] + 0.1 * ret_sd[1] else: ctc_loss = torch.empty(0) cr_loss = torch.empty(0) @@ -540,8 +549,10 @@ def forward( else: attention_decoder_loss = torch.empty(0) - reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, - encoder_out_lens) + reconstruction_loss = (0.9 * self.forward_reconstruction_loss(x_no_specaug, encoder_out, + encoder_out_lens) + + 0.1 * self.forward_reconstruction_loss(x_no_specaug, encoder_out_sd, + encoder_out_lens)) return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 14fe2dfaba..8da50bf0d1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -220,7 +220,7 @@ def forward( x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor]: """ Args: x: @@ -240,8 +240,9 @@ def forward( - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - lengths, a tensor of shape (batch_size,) containing the number of frames in `embeddings` before padding. - - predict_loss, a cross-prediction loss of randomized codebooks, relying on the CR-CTC - structure of the batch. + - embeddings_sd, a "stochastic-depth" version of embeddings that + is projected using a separate projection from random stacks, + differnently chosen per sequence. """ chunk_size, left_context_chunks = self.get_chunk_info() orig_seq_len = x.shape[0] @@ -261,11 +262,20 @@ def forward( num_stacks = len(self.downsampling_factor) + x_sd = x + + def combine_sd(i, x_sd, this_x_sd): + replace_prob = 1 / (i + 2) + batch_size = x_sd.shape[1] + do_replace = (torch.rand(1, batch_size, 1, device=x_sd.device) < replace_prob).expand_as(x_sd) + return torch.where(do_replace, this_x_sd, x_sd) + for i, module in enumerate(self.encoders): ds = self.downsampling_factor[i] x = downsample_by(x, ds) + x_sd = downsample_by(x_sd, ds) T = x.shape[0] - x = module( + x, this_x_sd = module( x, chunk_size=chunk_size, src_key_padding_mask=( @@ -279,13 +289,16 @@ def forward( ), aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) ) + x_sd = combine_sd(i, x_sd, this_x_sd) x = upsample_by(x, ds) - + x_sd = upsample_by(x_sd, ds) assert self.output_downsampling_factor == 2, self.output_downsampling_factor od = self.output_downsampling_factor x = downsample_by(x, od) x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + x_sd = downsample_by(x_sd, od) + x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 @@ -294,7 +307,7 @@ def forward( warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - return x, lengths + return x, lengths, x_sd def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -852,6 +865,10 @@ def __init__( self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) self.out_proj.lr_scale = 0.75 + # stochastic-depth proj. + self.sd_proj = nn.Linear(encoder_layer.embed_dim, dim) + + def forward( self, src: Tensor, @@ -859,7 +876,7 @@ def forward( attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, - ) -> Tensor: + ) -> Tuple[Tensor, Tensor]: r"""Pass the input through the encoder layers in turn. Args: @@ -873,7 +890,9 @@ def forward( src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. - Returns: a Tensor with the same shape as src. + Returns: + (src, src_sd) + where src_sd is an alternative version of src for stochastic-depth, that does not see the bypass. """ pos_emb = self.encoder_pos(src) @@ -896,7 +915,7 @@ def forward( # randomize_factor can be viewed as a simple version of an # importance-sampling factor. - src = self.add_residual(src_orig_fulldim, src_orig, src, aux_loss_scale, src_key_padding_mask) + src, src_sd = self.add_residual(src_orig_fulldim, src_orig, src, aux_loss_scale, src_key_padding_mask) # The above is equivalent to: # src = src_orig_fulldim + self.proj((src - src_orig) * self.residual_scale, transpose=True) # .. but with extra losses. @@ -904,7 +923,7 @@ def forward( if hasattr(self, 'out_proj'): src = self.out_proj(src) - return src + return src, src_sd def add_residual( @@ -913,10 +932,11 @@ def add_residual( src_orig, src, aux_loss_scale: float, - src_key_padding_mask: Optional[Tensor]): + src_key_padding_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]: + # return: (tot, src_sd) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0, training=self.training) offset = (src - src_orig) * residual_scale - + src_sd = self.sd_proj(offset) offset = self.proj(offset, transpose=True) tot = src_orig_fulldim + offset @@ -931,7 +951,7 @@ def add_residual( aux_loss_scale * 0.05, src_key_padding_mask), None) - return tot + return tot, src_sd def streaming_forward( @@ -2123,11 +2143,12 @@ def _test_zipformer_main(causal: bool = False): batch_size = 6 seq_len = 21 # Just make sure the forward pass runs. - f, lengths = c( + f, lengths, f_sd = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), aux_loss_scale=1.0, ) + assert f.shape == f_sd.shape f.sum().backward() c.eval() x_ = c( From ae2a9441ba754cd28e743ed361b931e40b0d71bb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 23:44:51 +0800 Subject: [PATCH 0522/1191] Bug fix --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 082762ab94..9dc9b138c2 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -154,7 +154,7 @@ def forward_encoder( # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) - x, x_lens, x_sd = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) + x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) # x_sd is stochastic-depth version of x. # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") From 6f0ecb0597c3e277311738f2868a2316763b6305 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 8 Sep 2025 23:47:45 +0800 Subject: [PATCH 0523/1191] Fix bug that was in 1136. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8da50bf0d1..87d13c3ae4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -945,10 +945,10 @@ def add_residual( tot = with_loss(tot, self.cosine_loss(tot_permuted, aux_loss_scale, src_key_padding_mask) + - self.encoder_cosine_loss(src, - aux_loss_scale, src_key_padding_mask) + self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), - aux_loss_scale * 0.05, src_key_padding_mask), + aux_loss_scale * 0.05, src_key_padding_mask) + + self.encoder_cosine_loss(src.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask), None) return tot, src_sd From 7bd2108b3e28a25438982ba51cc69080f1732c45 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 00:05:08 +0800 Subject: [PATCH 0524/1191] Bug fix --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 9dc9b138c2..b6ef035161 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -176,7 +176,7 @@ def forward_encoder( assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens, x_sd, predict_loss + return encoder_out, encoder_out_lens, encoder_out_sd, predict_loss def compute_predict_loss(self, From ff2ae20be8642d62e908037314bfc9d74e9fbe8a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 00:08:44 +0800 Subject: [PATCH 0525/1191] Bug fix --- egs/librispeech/ASR/zapformer/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index b6ef035161..377bea8f6f 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -173,6 +173,7 @@ def forward_encoder( 0.1 * self.compute_predict_loss(encoder_out_sd, src_key_padding_mask[:, ::2], specaug_mask[:, ::2])) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + encoder_out_sd = encoder_out_sd.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) From 1619492c6ddc5a991e8cd17782059138a1a37056 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 10:19:24 +0800 Subject: [PATCH 0526/1191] Fix to decode.py --- egs/librispeech/ASR/zapformer/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 504d1d94d2..221f01297b 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -452,7 +452,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens, _predict_loss = model.forward_encoder(feature, feature_lens) + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2] hyps = [] From 0a34a54262e4d299cf744e25aec8ea635eb30262 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 10:53:05 +0800 Subject: [PATCH 0527/1191] Remove min_product_loss and encoder_cosine_loss from Zipformer2Encoder. --- egs/librispeech/ASR/zipformer/zipformer.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 87d13c3ae4..0f22bfc7ed 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -851,12 +851,6 @@ def __init__( #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() - d_yes = encoder_layer.embed_dim - d_no = dim - encoder_layer.embed_dim - min_product = (d_yes * 0.75) / (d_yes + 0.75 * d_no) - self.min_product_loss = MinProductLoss(min_product) - - self.encoder_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear @@ -944,11 +938,7 @@ def add_residual( tot_permuted = tot.permute(1, 0, 2) tot = with_loss(tot, self.cosine_loss(tot_permuted, - aux_loss_scale, src_key_padding_mask) + - self.min_product_loss(tot_permuted, offset.permute(1, 0, 2), - aux_loss_scale * 0.05, src_key_padding_mask) + - self.encoder_cosine_loss(src.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask), + aux_loss_scale, src_key_padding_mask), None) return tot, src_sd From 69a2494f16ebf9cbd065b0dbb4415550f08352a7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 15:25:21 +0800 Subject: [PATCH 0528/1191] Reduce subsampling factor of central stack from 8 to 4, so there are 3 stacks with subsampling-factor=4. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 94eadc41ed..298d0efa19 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -192,7 +192,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--downsampling-factor", type=str, - default="1,2,4,8,4,2", + default="1,2,4,4,4,2", help="Downsampling factor for each stack of encoder layers.", ) From 1c8d21777021129f6f81d95dfa17f849278b3665 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 20:01:59 +0800 Subject: [PATCH 0529/1191] Introduce cosine similarity loss on new contribution of each zipformer layer (after multiplying by residual scale). power=0.8. --- egs/librispeech/ASR/zipformer/zipformer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0f22bfc7ed..d1717913be 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -576,18 +576,14 @@ def __init__( cnn_module_kernel: int = 31, num_conv_modules: int = 2, causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.75)), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim self.name = None # will be set from training loop - self.randomize_scale = copy.deepcopy(randomize_scale) + self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - # self.bypass implements layer skipping as well as learnable scale on a residual term; see its default values. - self.residual = ResidualModule( - embed_dim, - ) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, @@ -678,7 +674,13 @@ def forward( src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = self.residual(src_orig, src) + residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) + offset = (src_orig - src) * residual_scale + src = src + offset + + src = with_loss(src, + self.cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), + None) src = self.scale_limiter(src) @@ -906,8 +908,6 @@ def forward( src_key_padding_mask=src_key_padding_mask, aux_loss_scale=aux_loss_scale/num_layers, ) - # randomize_factor can be viewed as a simple version of an - # importance-sampling factor. src, src_sd = self.add_residual(src_orig_fulldim, src_orig, src, aux_loss_scale, src_key_padding_mask) # The above is equivalent to: From 41d8a493ca7215b952683a47d1e7327adf3fa7b3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 20:52:58 +0800 Subject: [PATCH 0530/1191] Have the cosine loss of the encoders be applied at the level of the scaled offset from the stack output. --- egs/librispeech/ASR/zipformer/zipformer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d1717913be..ef591065a7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -853,7 +853,7 @@ def __init__( #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear # module. @@ -930,17 +930,15 @@ def add_residual( # return: (tot, src_sd) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0, training=self.training) offset = (src - src_orig) * residual_scale - src_sd = self.sd_proj(offset) - offset = self.proj(offset, transpose=True) - tot = src_orig_fulldim + offset if aux_loss_scale: - tot_permuted = tot.permute(1, 0, 2) - tot = with_loss(tot, - self.cosine_loss(tot_permuted, - aux_loss_scale, src_key_padding_mask), - None) + offset = with_loss(offset, + self.cosine_loss(offset.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask), + None) + src_sd = self.sd_proj(offset) + tot = src_orig_fulldim + self.proj(offset, transpose=True) return tot, src_sd From d60b2533c9aae82120ba8ea180a8632619170f89 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 21:38:44 +0800 Subject: [PATCH 0531/1191] Make the cosine loss on the offset of the stack be in addition to the previous cosine loss, not instead of it. --- egs/librispeech/ASR/zipformer/zipformer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ef591065a7..2653b8343c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -850,10 +850,10 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(encoder_layer.embed_dim)) - #bypass_dim = dim - encoder_layer.embed_dim self.copy_bypass = Identity() self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear # module. @@ -931,14 +931,17 @@ def add_residual( residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0, training=self.training) offset = (src - src_orig) * residual_scale + tot = src_orig_fulldim + self.proj(offset, transpose=True) + if aux_loss_scale: - offset = with_loss(offset, - self.cosine_loss(offset.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask), - None) + tot = with_loss(tot, + self.offset_cosine_loss(offset.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask) + + self.cosine_loss(tot.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask), + None) src_sd = self.sd_proj(offset) - tot = src_orig_fulldim + self.proj(offset, transpose=True) return tot, src_sd From af2387b7fc6ee562562c60bd929cdf0caf17f0e2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 22:13:58 +0800 Subject: [PATCH 0532/1191] Bug fix on dims to cosine losses --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2653b8343c..778743723c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -852,8 +852,8 @@ def __init__( self.copy_bypass = Identity() - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) - self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear # module. From 34de15926a9d080da451abb490a3cb4fcb8b6d01 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Sep 2025 23:32:43 +0800 Subject: [PATCH 0533/1191] Increase power of cosine loss on feedforward and conv modules from 0.5 to 0.7. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 778743723c..ef05ddfd04 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1783,7 +1783,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): initial_scale=0.5, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.5)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: @@ -2005,7 +2005,7 @@ def __init__( dropout_p=0.0, initial_scale=0.05, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.5)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.7)) def forward( From b24a06435e73630f5aa78f9269510db088018a84 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Sep 2025 00:00:16 +0800 Subject: [PATCH 0534/1191] Increase power of cosine_loss of feedforward module from 0.7 to 0.8. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ef05ddfd04..bb9b64c7ef 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1783,7 +1783,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): initial_scale=0.5, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.8)) def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: From d42465ce0167b2ddbc8fc0840951f489085d0d19 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Sep 2025 00:14:41 +0800 Subject: [PATCH 0535/1191] Decrease power of cosine_loss of conv module from 0.7 to 0.6. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bb9b64c7ef..62d5f345ab 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -2005,7 +2005,7 @@ def __init__( dropout_p=0.0, initial_scale=0.05, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.7)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.6)) def forward( From b51d44f3a97e4ff881bf085848bb4a140299c357 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Sep 2025 10:11:38 +0800 Subject: [PATCH 0536/1191] Reduce power of cosine loss in feedforward module from 0.8 to 0.7. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 62d5f345ab..038c241e19 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1783,7 +1783,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): initial_scale=0.5, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.8)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: From 2f7c72736a2b4734e21f4dad119485cec7da3119 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Sep 2025 11:55:20 +0800 Subject: [PATCH 0537/1191] Bug fix regarding src - src_orig being backwards. Affects back to 1152. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 038c241e19..a0a58c18ab 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -675,8 +675,8 @@ def forward( src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) - offset = (src_orig - src) * residual_scale - src = src + offset + offset = (src - src_orig) * residual_scale + src = src_orig + offset src = with_loss(src, self.cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), From 594b6b029d30e44cd2785379c3ef6b560da66dea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Sep 2025 16:20:26 +0800 Subject: [PATCH 0538/1191] Multiply aux_loss_scale of frontend by 0.1. --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 377bea8f6f..577400b3dd 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -154,7 +154,7 @@ def forward_encoder( # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) - x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) + x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=0.1*aux_loss_scale) # x_sd is stochastic-depth version of x. # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") From c50248f0eba3dbc565646e43f4a302a07a9484ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Sep 2025 17:14:49 +0800 Subject: [PATCH 0539/1191] Revert scale of 0.1 on aux_loss_scale of Conv2dSubsampling --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 577400b3dd..377bea8f6f 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -154,7 +154,7 @@ def forward_encoder( # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) - x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=0.1*aux_loss_scale) + x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) # x_sd is stochastic-depth version of x. # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") From 80503f4c741239cfda61fa1a424b6bdc40d0a688 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Sep 2025 17:15:34 +0800 Subject: [PATCH 0540/1191] Introduce aux_loss_scale scale of 0.25 for batch indexes after 2000. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 298d0efa19..8958cd6c68 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1195,7 +1195,7 @@ def save_bad_model(suffix: str = ""): batch=batch, is_training=True, spec_augment=spec_augment, - aux_loss_scale=get_scaler_scale() * params.aux_loss_scale, + aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info From 99eac46ac09363ed74d881c4dcde0062cc5dcd07 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 11 Sep 2025 12:44:03 +0800 Subject: [PATCH 0541/1191] Implement stochastic depth with random replacement, not weighted losses. --- egs/librispeech/ASR/zapformer/model.py | 58 +++++++++++----------- egs/librispeech/ASR/zapformer/train.py | 9 ++++ egs/librispeech/ASR/zipformer/zipformer.py | 47 ++++++++++-------- 3 files changed, 66 insertions(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 377bea8f6f..1d768d1552 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -131,8 +131,8 @@ def __init__( def forward_encoder( - self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, sd_prob: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute encoder outputs. Args: x: @@ -140,22 +140,23 @@ def forward_encoder( x_lens: A 1-D tensor of shape (N,). It contains the number of frames in `x` before padding. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. + Returns: encoder_out: Encoder output, of shape (N, T, C). encoder_out_lens: Encoder output lengths, of shape (N,). - encoder_out_sd: - Stochastic-depth version of encoder output - predict_loss: - Cross-sequence prediction loss value """ # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) - # x_sd is stochastic-depth version of x. # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") @@ -166,18 +167,17 @@ def forward_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens, encoder_out_sd = self.encoder(x, x_lens, src_key_padding_mask, - aux_loss_scale=aux_loss_scale) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, + aux_loss_scale=aux_loss_scale, + sd_prob=(0.1 if self.training else 0.0)) - predict_loss = (0.9 * self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) + - 0.1 * self.compute_predict_loss(encoder_out_sd, src_key_padding_mask[:, ::2], specaug_mask[:, ::2])) + predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - encoder_out_sd = encoder_out_sd.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens, encoder_out_sd, predict_loss + return encoder_out, encoder_out_lens, predict_loss def compute_predict_loss(self, @@ -398,6 +398,7 @@ def forward( time_warp_factor: Optional[int] = 80, num_copies: int = 1, aux_loss_scale: float = 0.0, + sd_prob: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -432,6 +433,11 @@ def forward( num_copies: the number of copies of the same data that are in the batch, e.g. 1, 2 or 3; affects CRCTC, spec-augment, etc. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. Returns: Return the transducer losses, CTC loss, AED loss, @@ -495,8 +501,9 @@ def forward( # Compute encoder outputs - encoder_out, encoder_out_lens, encoder_out_sd, predict_loss = self.forward_encoder(x, x_lens, - aux_loss_scale=aux_loss_scale) + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale, + sd_prob=sd_prob) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -519,23 +526,20 @@ def forward( if self.use_ctc: targets = y.values if not self.training: - ctc_loss, ctc_loss_sd = [ self.forward_ctc( - encoder_out=e, + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, targets=targets, target_lengths=y_lens, - ) for e in [encoder_out, encoder_out_sd] ] - ctc_loss = 0.9 * ctc_loss + 0.1 * ctc_loss_sd + ) cr_loss = torch.empty(0) else: - ret, ret_sd = [ self.forward_cr_ctc( - encoder_out=e, + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, targets=targets, target_lengths=y_lens, - ) for e in [encoder_out, encoder_out_sd] ] - ctc_loss = 0.9 * ret[0] + 0.1 * ret_sd[0] - cr_loss = 0.9 * ret[1] + 0.1 * ret_sd[1] + ) else: ctc_loss = torch.empty(0) cr_loss = torch.empty(0) @@ -550,10 +554,8 @@ def forward( else: attention_decoder_loss = torch.empty(0) - reconstruction_loss = (0.9 * self.forward_reconstruction_loss(x_no_specaug, encoder_out, - encoder_out_lens) + - 0.1 * self.forward_reconstruction_loss(x_no_specaug, encoder_out_sd, - encoder_out_lens)) + reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, + encoder_out_lens) return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 8958cd6c68..ffffded4bc 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -550,6 +550,14 @@ def get_parser(): help="Prediction of random k-means after widest zipformer layer" ) + parser.add_argument( + "--stochastic-depth-prob", + type=float, + default=0.1, + help="Probability of using a randomly chosen stack output during training, instead of " + "final output." + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -1008,6 +1016,7 @@ def compute_loss( time_warp_factor=80, # for specaug num_copies=num_copies, aux_loss_scale=aux_loss_scale, + sd_prob=(params.stochastic_depth_prob if is_training else 0.0), ) loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a0a58c18ab..0698dbd23f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -220,7 +220,8 @@ def forward( x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, - ) -> Tuple[Tensor, Tensor, Tensor]: + sd_prob: float = 0.0, + ) -> Tuple[Tensor, Tensor]: """ Args: x: @@ -235,14 +236,17 @@ def forward( If supplied, auxiliary losses such as CosineSimilarityLoss will be applied with this scale on the loss (note, these aux losses are reduced via summation over frames.) + sd_prob: + Stochastic-depth prob: with this probability we replace the final output + with the output of a randomly chosen stack (including the 'zero stack' which + means the original input x). Each stack except the 'zero stack' has a + separate output projection for stochastic depth, that only sees the + "non-bypass part", i.e. its encoder stack without the residual. Returns: - Return a tuple containing 4 tensors: + Return (embeddings_lengths), where: - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - lengths, a tensor of shape (batch_size,) containing the number of frames in `embeddings` before padding. - - embeddings_sd, a "stochastic-depth" version of embeddings that - is projected using a separate projection from random stacks, - differnently chosen per sequence. """ chunk_size, left_context_chunks = self.get_chunk_info() orig_seq_len = x.shape[0] @@ -264,16 +268,14 @@ def forward( x_sd = x - def combine_sd(i, x_sd, this_x_sd): - replace_prob = 1 / (i + 2) - batch_size = x_sd.shape[1] - do_replace = (torch.rand(1, batch_size, 1, device=x_sd.device) < replace_prob).expand_as(x_sd) - return torch.where(do_replace, this_x_sd, x_sd) + def randomly_choose_seqs(x, this_x, prob: float): + batch_size = x.shape[1] + do_replace = (torch.rand(1, batch_size, 1, device=x.device) < prob).expand_as(x) + return torch.where(do_replace, this_x, x) for i, module in enumerate(self.encoders): ds = self.downsampling_factor[i] x = downsample_by(x, ds) - x_sd = downsample_by(x_sd, ds) T = x.shape[0] x, this_x_sd = module( x, @@ -289,16 +291,15 @@ def combine_sd(i, x_sd, this_x_sd): ), aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) ) - x_sd = combine_sd(i, x_sd, this_x_sd) x = upsample_by(x, ds) - x_sd = upsample_by(x_sd, ds) + if sd_prob: + x_sd = randomly_choose_seqs(x_sd, upsample_by(this_x_sd, ds), 1. / (2. + i)) + assert self.output_downsampling_factor == 2, self.output_downsampling_factor od = self.output_downsampling_factor x = downsample_by(x, od) x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding - x_sd = downsample_by(x_sd, od) - x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 @@ -307,7 +308,12 @@ def combine_sd(i, x_sd, this_x_sd): warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - return x, lengths, x_sd + if sd_prob: + x_sd = downsample_by(x_sd, od) + x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + x = randomly_choose_seqs(x, x_sd, sd_prob) + + return x, lengths def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -887,8 +893,8 @@ def forward( masked position. May be None. Returns: - (src, src_sd) - where src_sd is an alternative version of src for stochastic-depth, that does not see the bypass. + (out, out_sd), both of the same shape as src, + where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. """ pos_emb = self.encoder_pos(src) @@ -2134,18 +2140,19 @@ def _test_zipformer_main(causal: bool = False): batch_size = 6 seq_len = 21 # Just make sure the forward pass runs. - f, lengths, f_sd = c( + f, lengths = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), aux_loss_scale=1.0, + sd_prob=0.1, ) - assert f.shape == f_sd.shape f.sum().backward() c.eval() x_ = c( torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), aux_loss_scale=1.0, + sd_prob=0.1, ) x_ # to remove flake8 warnings From 608a9737bc9edf411f5816a5267e0f800e1e4ced Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 16 Sep 2025 17:15:01 +0800 Subject: [PATCH 0542/1191] In individual stacks, implement a more general form of bypass that allows us to see intermediate layers. --- egs/librispeech/ASR/zipformer/zipformer.py | 48 +++++++++++----------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0698dbd23f..df1bc93ca8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -854,7 +854,10 @@ def __init__( ) self.num_layers = num_layers - self.residual_scale = nn.Parameter(0.5 * torch.ones(encoder_layer.embed_dim)) + self.residual_scales = nn.Parameter( + torch.cat([ 0.25 * torch.ones(1, encoder_layer.embed_dim), + (0.25 / (num_layers - 1) ) * torch.ones(num_layers - 1, encoder_layer.embed_dim)], + dim=0)) self.copy_bypass = Identity() @@ -904,8 +907,13 @@ def forward( num_layers = len(self.layers) src_orig = src + src_with_bypass = 0.0 for i, mod in enumerate(self.layers): + residual_scale = limit_param_value(self.residual_scales[i], min=0.0, + max=0.9 if i == 0 else 1. / num_layers) + + src_with_bypass = src_with_bypass + self.residual_scales[i] * src src = mod( src, pos_emb, @@ -915,40 +923,30 @@ def forward( aux_loss_scale=aux_loss_scale/num_layers, ) - src, src_sd = self.add_residual(src_orig_fulldim, src_orig, src, aux_loss_scale, src_key_padding_mask) - # The above is equivalent to: - # src = src_orig_fulldim + self.proj((src - src_orig) * self.residual_scale, transpose=True) - # .. but with extra losses. - - if hasattr(self, 'out_proj'): - src = self.out_proj(src) + residual_scale = limit_param_value(1. - self.residual_scales.sum(dim=0), + min=0.1, max=1.0) + src_with_bypass = src_with_bypass + self.residual_scales[i] * src - return src, src_sd + offset = src_with_bypass - src_orig - - def add_residual( - self, - src_orig_fulldim, - src_orig, - src, - aux_loss_scale: float, - src_key_padding_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]: - # return: (tot, src_sd) - residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0, training=self.training) - offset = (src - src_orig) * residual_scale - - tot = src_orig_fulldim + self.proj(offset, transpose=True) + src = src_orig_fulldim + self.proj(offset, transpose=True) + # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, + # because of some identities involving orthogonal matrices. if aux_loss_scale: - tot = with_loss(tot, + src = with_loss(src, self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask) + - self.cosine_loss(tot.permute(1, 0, 2), + self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), None) src_sd = self.sd_proj(offset) - return tot, src_sd + + if hasattr(self, 'out_proj'): + src = self.out_proj(src) + + return src, src_sd def streaming_forward( From f24b81cb016c20979fdf04ccd9d3511df4fde06e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 17 Sep 2025 13:02:05 +0800 Subject: [PATCH 0543/1191] Bug fix to the residual scales, make it apply the limits and constraints and fix bug for last layer. --- egs/librispeech/ASR/zipformer/zipformer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index df1bc93ca8..23d47daf23 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -911,9 +911,8 @@ def forward( for i, mod in enumerate(self.layers): residual_scale = limit_param_value(self.residual_scales[i], min=0.0, - max=0.9 if i == 0 else 1. / num_layers) - - src_with_bypass = src_with_bypass + self.residual_scales[i] * src + max=1.0) + src_with_bypass = src_with_bypass + residual_scale * src src = mod( src, pos_emb, @@ -925,7 +924,7 @@ def forward( residual_scale = limit_param_value(1. - self.residual_scales.sum(dim=0), min=0.1, max=1.0) - src_with_bypass = src_with_bypass + self.residual_scales[i] * src + src_with_bypass = src_with_bypass + residual_scale * src offset = src_with_bypass - src_orig From e4a34056d1eecec24ad6953fa7f0c2676218f9f0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 17 Sep 2025 21:22:41 +0800 Subject: [PATCH 0544/1191] Remove the sum-to-one constraint on residuals and make last residual scale free. --- egs/librispeech/ASR/zipformer/zipformer.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 23d47daf23..0b86763f88 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -855,8 +855,9 @@ def __init__( self.num_layers = num_layers self.residual_scales = nn.Parameter( - torch.cat([ 0.25 * torch.ones(1, encoder_layer.embed_dim), - (0.25 / (num_layers - 1) ) * torch.ones(num_layers - 1, encoder_layer.embed_dim)], + torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), + (0.25 / (num_layers - 1) ) * torch.ones(num_layers - 1, encoder_layer.embed_dim), + 0.75 * torch.ones(1, encoder_layer.embed_dim) ], dim=0)) self.copy_bypass = Identity() @@ -907,11 +908,12 @@ def forward( num_layers = len(self.layers) src_orig = src - src_with_bypass = 0.0 + + residual_scale = limit_param_value(self.residual_scales[0], + min=-1.0, max=0.0) + src_with_bypass = residual_scale * src for i, mod in enumerate(self.layers): - residual_scale = limit_param_value(self.residual_scales[i], min=0.0, - max=1.0) src_with_bypass = src_with_bypass + residual_scale * src src = mod( src, @@ -921,12 +923,13 @@ def forward( src_key_padding_mask=src_key_padding_mask, aux_loss_scale=aux_loss_scale/num_layers, ) + residual_scale = limit_param_value(self.residual_scales[i + 1], + min=0.0 if i + 1 < num_layers else 0.1, + max=1.0) + src_with_bypass = src_with_bypass + residual_scale * src - residual_scale = limit_param_value(1. - self.residual_scales.sum(dim=0), - min=0.1, max=1.0) - src_with_bypass = src_with_bypass + residual_scale * src - offset = src_with_bypass - src_orig + offset = src_with_bypass src = src_orig_fulldim + self.proj(offset, transpose=True) # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, From 6a37b2c77e6e84829a96c6f307a33ea8a6008b75 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 17 Sep 2025 22:48:20 +0800 Subject: [PATCH 0545/1191] Change sd_prob from 0.1 to 0.0; remove out_proj from some Zipformer2Encoder modules; have one, not zero, conv_module if downsampling_factor >= 8; make central downsampling_factor be 8 not 4. --- egs/librispeech/ASR/zapformer/model.py | 2 +- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 1d768d1552..278e498032 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -169,7 +169,7 @@ def forward_encoder( encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, aux_loss_scale=aux_loss_scale, - sd_prob=(0.1 if self.training else 0.0)) + sd_prob=0.0) predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ffffded4bc..18b842ca7c 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -192,7 +192,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--downsampling-factor", type=str, - default="1,2,4,4,4,2", + default="1,2,4,8,4,2", help="Downsampling factor for each stack of encoder layers.", ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0b86763f88..f5234a8821 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -168,7 +168,7 @@ def _to_tuple(x): feedforward_multiple=feedforward_multiple[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], - num_conv_modules=(2 if downsampling_factor[i] <= 2 else (1 if downsampling_factor[i] <= 4 else 0)), + num_conv_modules=(2 if downsampling_factor[i] <= 2 else 1), causal=causal, ) @@ -179,7 +179,7 @@ def _to_tuple(x): num_encoder_layers[i], dim=downsampling_factor[i]*input_dim, pos_dim=pos_dim, - out_proj=(downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], + out_proj=False, # (downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], ) encoders.append(encoder) From d6aee09f16872ce3f39ec2798e536bc726540a69 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 18 Sep 2025 11:05:55 +0800 Subject: [PATCH 0546/1191] Limit residual_scales[0] to -1..-0.5, not -1..0. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f5234a8821..55f06809d5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -910,7 +910,7 @@ def forward( src_orig = src residual_scale = limit_param_value(self.residual_scales[0], - min=-1.0, max=0.0) + min=-1.0, max=-0.5) src_with_bypass = residual_scale * src for i, mod in enumerate(self.layers): From 26e012f08022044af0e166268124fc18396257ae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Sep 2025 12:22:08 +0800 Subject: [PATCH 0547/1191] Change initialization of residual scales to give less weight to last layer; reduce minimum of last layer's weight to be 0.05 instead of 0.1. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f5234a8821..bcb3de3c79 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -856,8 +856,7 @@ def __init__( self.residual_scales = nn.Parameter( torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), - (0.25 / (num_layers - 1) ) * torch.ones(num_layers - 1, encoder_layer.embed_dim), - 0.75 * torch.ones(1, encoder_layer.embed_dim) ], + (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim)], dim=0)) self.copy_bypass = Identity() @@ -924,7 +923,7 @@ def forward( aux_loss_scale=aux_loss_scale/num_layers, ) residual_scale = limit_param_value(self.residual_scales[i + 1], - min=0.0 if i + 1 < num_layers else 0.1, + min=0.05 if i + 1 < num_layers else 0.1, max=1.0) src_with_bypass = src_with_bypass + residual_scale * src From 43646b1fe96f095067d7daa1e9834c495bcf460d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Sep 2025 12:50:51 +0800 Subject: [PATCH 0548/1191] Bug fix to stop src_with_bypass from being added twice. --- egs/librispeech/ASR/zipformer/zipformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bcb3de3c79..7e831ce851 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -913,7 +913,6 @@ def forward( src_with_bypass = residual_scale * src for i, mod in enumerate(self.layers): - src_with_bypass = src_with_bypass + residual_scale * src src = mod( src, pos_emb, From db1035bc3fa4cacf06398cea2f8fd895ad9fa1f3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Sep 2025 13:01:40 +0800 Subject: [PATCH 0549/1191] Fix adding-terms-twice bug in residual scales; change initialization to give less weight to last layer, i.e. no specially high weight for last layer. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 55f06809d5..4967920b62 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -856,8 +856,7 @@ def __init__( self.residual_scales = nn.Parameter( torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), - (0.25 / (num_layers - 1) ) * torch.ones(num_layers - 1, encoder_layer.embed_dim), - 0.75 * torch.ones(1, encoder_layer.embed_dim) ], + (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], dim=0)) self.copy_bypass = Identity() @@ -914,7 +913,6 @@ def forward( src_with_bypass = residual_scale * src for i, mod in enumerate(self.layers): - src_with_bypass = src_with_bypass + residual_scale * src src = mod( src, pos_emb, From bf50a1323be95ee6eba29c347eb4ba10d853bd12 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Sep 2025 13:16:36 +0800 Subject: [PATCH 0550/1191] Bug fix for limits to residual_scale --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7e831ce851..65f115e4ca 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -922,7 +922,7 @@ def forward( aux_loss_scale=aux_loss_scale/num_layers, ) residual_scale = limit_param_value(self.residual_scales[i + 1], - min=0.05 if i + 1 < num_layers else 0.1, + min=0.0 if i + 1 < num_layers else 0.05, max=1.0) src_with_bypass = src_with_bypass + residual_scale * src From e7a27d744e9550560bba2a245274666c1bce73d9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Sep 2025 02:56:35 +0800 Subject: [PATCH 0551/1191] Change optim.py so that TransformedAdam uses abs, not rms, values for normalization; adjust LRs inside TransformedAdam so we do not have to re-tunen them. --- egs/librispeech/ASR/zipformer/optim.py | 113 +++++++++++++++---------- 1 file changed, 70 insertions(+), 43 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9693c61691..ac4dd403e2 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -167,6 +167,19 @@ def momentum_step(group, p, state, grad): stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) state["delta"] = stored_delta + # 1.2533141373155001 is sqrt(pi/2) which is a correction factor for the + # ratio of (rms value / abs value) of a normal distribution, made when we + # switched from using rms values to abs value for purposes of scaling. This + # does not apply to scalar parameters (p.numel() == p.shape[0], dimension 0 + # is the same-sized-parameter-tensor batch dimension), which are not subject + # to scaling by inverse-absolute-values. The update is going to get + # multiplied by the mean-absolute-value, i.e. the scaling factor, which is + # equal to sqrt(2/pi) times the rms value for normally distributed data, and + # we want the step size to be the same as before for normally distributed + # data, which means we need to multiply by sqrt(pi/2). + lr = (1.2533141373155001 * lr if p.numel() > p.shape[0] else lr) + + stored_delta.mul_(beta1).add_(delta) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) @@ -187,13 +200,16 @@ def forward_transform_param(group, p): return p.reshape(batch_size, 1) / group["scalar_lr_scale"] is_weight = (p.ndim > 2) - min_rms = group["weight_min_rms"] if is_weight else group["bias_min_rms"] + # 0.7978845608028654 is sqrt(2/pi) which is a correction factor for the ratio of (abs value / rms value) + # of a normal distribution, made when we switched from using rms values to abs value for purposes + # of scaling. + min_scale = 0.7978845608028654 * (group["weight_min_scale"] if is_weight else group["bias_min_scale"]) p_flat = p.reshape(batch_size, numel) - sumsq = (p_flat ** 2).sum(dim=1, keepdim=True) - min_sumsq = (min_rms ** 2) * numel # if sumsq is less than this we pad with an extra element. - sumsq_clamped = sumsq.clamp(min=min_sumsq) - pad = (sumsq_clamped - sumsq).sqrt() - scale = (sumsq_clamped / numel).sqrt() # must be nonzero thanks to min_rms + abs_sum = p_flat.abs().sum(dim=1, keepdim=True) + min_abs_sum = min_scale * numel # if sumsq is less than this we pad with an extra element. + abs_sum_clamped = abs_sum.clamp(min=min_abs_sum) + pad = (abs_sum_clamped - abs_sum) + scale = (abs_sum_clamped / numel) # must be nonzero thanks to min_abs_sum # scaling_lr_scale is to control the learning-rate of scaling factors. # log_scale controls the overall scale of this tensor @@ -209,12 +225,21 @@ def reverse_transform_param(group, p, orig_shape): # numel is num elements of each parameter tensor in the batch. numel = p.shape[1] - 2 p_padded = p[:, :numel+1] # orig tensor plus one padding element - p_padded = p_padded / ((p_padded ** 2).sum(dim=1, keepdim=True) / numel).sqrt() # normalize rms to 1. + # the next line normalizes the scale to 1, because the update step will have + # changed it slightly versus the normalized state that forward_transform_param + # put it into. The correction factor (numel + 1) / numel is to account + # for the fact that it's actuallty the sum() / numel that should equal 1, + # but we prefer to use mean to avoid out-of-range numerical errors for large tensors + # if this code gets used in fp16 in the future. + p_padded = p_padded / (p_padded.abs().mean(dim=1, keepdim=True) * ((numel + 1) / numel)) is_weight = (len(orig_shape) > 2) - max_rms = group["weight_max_rms"] if is_weight else group["bias_max_rms"] - min_rms = group["weight_min_rms"] if is_weight else group["bias_min_rms"] - scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]).exp().clamp(min=min_rms, max=max_rms) + # 0.7978845608028654 is sqrt(2/pi) which is a correction factor for the ratio of (abs value / rms value) + # of a normal distribution, made when we switched from using rms values to abs value for purposes + # of scaling. + max_scale = 0.7978845608028654 * (group["weight_max_scale"] if is_weight else group["bias_max_scale"]) + min_scale = 0.7978845608028654 * (group["weight_min_scale"] if is_weight else group["bias_min_scale"]) + scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]).exp().clamp(min=min_scale, max=max_scale) q = p_padded[:, :-1] * scale # the :-1 is to remove the padding element. q = q.reshape(*orig_shape) @@ -362,11 +387,13 @@ class TransformedAdam(BatchedOptimizer): would be a the scaling factor on the learning rate of p_scale. scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. eps: A general-purpose epsilon to prevent division by zero - weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of - learning the scale on the parameters. Weight tensors are defined - as anything with more than one element and ndim > 1. - bias_min_rms: Minimum root-mean-square value of bias tensors, defined as anything with - more than one element and exactly one tensor dimension i.e. ndim == 1. + weight_min_scale, weight_max_scale: Minimum and maximum respectively of weight tensor + scales (mean-absolute-value), for purposes of + learning the scale on the parameters. Weight tensors, as distinct from bias + tensors and scalars, are defined as anything with more than one element and ndim > 1. + bias_min_scale, bias_max_scale: Minimum and maximum respetively of bias tensor scales, + defined as anything with more than one element and exactly one tensor dimension i.e. + ndim == 1. debug_interval: if >0, write some statistics to tensorboard every this-many steps. """ def __init__( @@ -380,10 +407,10 @@ def __init__( scalar_lr_scale=0.1, scaling_lr_scale=0.1, eps=1.0e-08, - weight_min_rms=0.005, - weight_max_rms=1.0, - bias_min_rms=1.0e-05, - bias_max_rms=5.0, + weight_min_scale=0.005, + weight_max_scale=1.0, + bias_min_scale=1.0e-05, + bias_max_scale=5.0, size_update_period=4, clipping_update_period=100, debug_interval=0, @@ -398,10 +425,10 @@ def __init__( scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, - weight_min_rms=weight_min_rms, - bias_max_rms=bias_max_rms, - bias_min_rms=bias_min_rms, - weight_max_rms=weight_max_rms, + weight_min_scale=weight_min_scale, + bias_max_scale=bias_max_scale, + bias_min_scale=bias_min_scale, + weight_max_scale=weight_max_scale, clipping_update_period=clipping_update_period, debug_interval=debug_interval, ) @@ -863,10 +890,10 @@ def __init__( scalar_lr_scale=0.1, scaling_lr_scale=0.1, eps=1.0e-08, - weight_min_rms=0.005, - weight_max_rms=1.0, - bias_min_rms=1.0e-05, - bias_max_rms=5.0, + weight_min_scale=0.005, + weight_max_scale=1.0, + bias_min_scale=1.0e-05, + bias_max_scale=5.0, debug_interval=0, ): @@ -879,10 +906,10 @@ def __init__( scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, - weight_min_rms=weight_min_rms, - bias_max_rms=bias_max_rms, - bias_min_rms=bias_min_rms, - weight_max_rms=weight_max_rms, + weight_min_scale=weight_min_scale, + bias_max_scale=bias_max_scale, + bias_min_scale=bias_min_scale, + weight_max_scale=weight_max_scale, debug_interval=debug_interval, ) super().__init__(params, defaults) @@ -1421,7 +1448,7 @@ def step(self, closure=None): def _test_transformed_adam(hidden_dim: int): import timeit - from scaling import ScaledLinear, OrthogonalLinear + from scaling import OrthogonalLinear E = 100 B = 4 @@ -1438,9 +1465,9 @@ def _test_transformed_adam(hidden_dim: int): input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - for iter in [0, 1, 2]: + for test in [0, 1, 2]: fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear + Linear = torch.nn.Linear m = torch.nn.Sequential( Linear(E, hidden_dim), @@ -1462,14 +1489,14 @@ def _test_transformed_adam(hidden_dim: int): for _ in range(20) ] - if iter == 0: + if test == 0: optim = SimpleTransformedAdam(m.parameters(), lr=0.06, eps=1.0e-20) - elif iter == 1: + elif test == 1: optim = TransformedAdam(m.named_parameters(), lr=0.06, clipping_scale=2.0, eps=1.0e-20) - elif iter == 2: + elif test == 2: optim = Eve(m.parameters(), lr=0.003) else: - assert "unknown iter", iter + assert "unknown test", test scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) @@ -1477,7 +1504,7 @@ def _test_transformed_adam(hidden_dim: int): avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: + # if epoch == 100 and test in [2,3]: # optim.reset_speedup() # check it doesn't crash. # if epoch == 130: @@ -1505,7 +1532,7 @@ def _test_transformed_adam(hidden_dim: int): lr = scheduler.get_last_lr()[0] logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3,norm4}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + f"Test {test}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3,norm4}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" ) loss.log().backward() optim.step() @@ -1515,7 +1542,7 @@ def _test_transformed_adam(hidden_dim: int): # diagnostic.print_diagnostics() stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") + logging.info(f"Test={test}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") # logging.info("state dict = ", scheduler.state_dict()) @@ -1525,8 +1552,8 @@ def _test_transformed_adam(hidden_dim: int): def _test_transform_params(): # caution: this has occasional errors. - group = { "bias_min_rms": 0.001, "weight_min_rms": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5, - "weight_max_rms": 20.0, "bias_max_rms": 20.0 } + group = { "bias_min_scale": 0.001, "weight_min_scale": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5, + "weight_max_scale": 20.0, "bias_max_scale": 20.0 } for scale in [ 0.0, 1.0e-05, 0.001, 0.01, 1.0, 10.0 ]: for shape in [ (1, 1), (2, 1), (2, 2), (2, 3, 4), (3, 10, 20), (4,) ]: p = scale * torch.randn(*shape) From cbf5efd88890c3de522cc3279f6305f3cf3aabba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Sep 2025 03:53:59 +0800 Subject: [PATCH 0552/1191] Take zipformer.py from 1218conv --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 65f115e4ca..4967920b62 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -856,7 +856,7 @@ def __init__( self.residual_scales = nn.Parameter( torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), - (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim)], + (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], dim=0)) self.copy_bypass = Identity() @@ -909,7 +909,7 @@ def forward( src_orig = src residual_scale = limit_param_value(self.residual_scales[0], - min=-1.0, max=0.0) + min=-1.0, max=-0.5) src_with_bypass = residual_scale * src for i, mod in enumerate(self.layers): @@ -922,7 +922,7 @@ def forward( aux_loss_scale=aux_loss_scale/num_layers, ) residual_scale = limit_param_value(self.residual_scales[i + 1], - min=0.0 if i + 1 < num_layers else 0.05, + min=0.0 if i + 1 < num_layers else 0.1, max=1.0) src_with_bypass = src_with_bypass + residual_scale * src From 0d272be9240c06248afb856dec437cba6ab66670 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Sep 2025 07:42:47 +0800 Subject: [PATCH 0553/1191] Make cutoff for num_conv_modules==2 be downsampling_factor<=1, not downsampling_factor<=2. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4967920b62..90ae5c2d54 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -168,7 +168,7 @@ def _to_tuple(x): feedforward_multiple=feedforward_multiple[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], - num_conv_modules=(2 if downsampling_factor[i] <= 2 else 1), + num_conv_modules=(2 if downsampling_factor[i] == 1 else 1), causal=causal, ) From 15fd78348c17d7da246db80f65b8a3ffa21454b8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Sep 2025 00:00:21 +0800 Subject: [PATCH 0554/1191] Make num_conv_modules always be 1, even when downsampling_factor==1. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 90ae5c2d54..3e6f2c0db7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -168,7 +168,7 @@ def _to_tuple(x): feedforward_multiple=feedforward_multiple[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], - num_conv_modules=(2 if downsampling_factor[i] == 1 else 1), + num_conv_modules=1, causal=causal, ) From 51bec4105915e6349980d98240f0566892422aa1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Sep 2025 02:26:42 +0800 Subject: [PATCH 0555/1191] Take some files from 1226 to allow parallel run. --- egs/librispeech/ASR/zapformer2/.gitignore | 1 + .../ASR/zapformer2/asr_datamodule.py | 454 ++++ .../ASR/zapformer2/attention_decoder.py | 1 + egs/librispeech/ASR/zapformer2/beam_search.py | 1 + egs/librispeech/ASR/zapformer2/ctc_decode.py | 1 + egs/librispeech/ASR/zapformer2/decode.py | 1089 +++++++++ .../ASR/zapformer2/decode_gigaspeech.py | 1 + .../ASR/zapformer2/decode_stream.py | 1 + egs/librispeech/ASR/zapformer2/decoder.py | 1 + .../ASR/zapformer2/encoder_interface.py | 1 + .../ASR/zapformer2/export-onnx-ctc.py | 1 + .../zapformer2/export-onnx-streaming-ctc.py | 1 + .../ASR/zapformer2/export-onnx-streaming.py | 1 + egs/librispeech/ASR/zapformer2/export-onnx.py | 1 + egs/librispeech/ASR/zapformer2/export.py | 1 + egs/librispeech/ASR/zapformer2/finetune.py | 1 + .../ASR/zapformer2/generate_averaged_model.py | 1 + .../ASR/zapformer2/jit_pretrained.py | 1 + .../ASR/zapformer2/jit_pretrained_ctc.py | 1 + .../zapformer2/jit_pretrained_streaming.py | 1 + egs/librispeech/ASR/zapformer2/joiner.py | 1 + .../ASR/zapformer2/label_smoothing.py | 1 + egs/librispeech/ASR/zapformer2/model.py | 630 +++++ egs/librispeech/ASR/zapformer2/my_profile.py | 1 + egs/librispeech/ASR/zapformer2/onnx_check.py | 1 + egs/librispeech/ASR/zapformer2/onnx_decode.py | 1 + .../onnx_pretrained-streaming-ctc.py | 1 + .../zapformer2/onnx_pretrained-streaming.py | 1 + .../ASR/zapformer2/onnx_pretrained.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_H.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_HLG.py | 1 + .../onnx_pretrained_ctc_HLG_streaming.py | 1 + egs/librispeech/ASR/zapformer2/optim.py | 1 + egs/librispeech/ASR/zapformer2/pretrained.py | 1 + .../ASR/zapformer2/pretrained_ctc.py | 1 + .../relative_position_attention_bwd_k_2.py | 321 +++ .../relative_position_attention_bwd_pos_2.py | 321 +++ .../relative_position_attention_bwd_q_2.py | 332 +++ .../relative_position_attention_fwd_2.py | 302 +++ ...ive_position_attention_module_optimized.py | 118 + egs/librispeech/ASR/zapformer2/scaling.py | 1 + .../ASR/zapformer2/scaling_converter.py | 1 + .../ASR/zapformer2/speech_recognition.py | 229 ++ .../ASR/zapformer2/streaming_beam_search.py | 1 + .../ASR/zapformer2/streaming_decode.py | 1 + egs/librispeech/ASR/zapformer2/subsampling.py | 1 + .../ASR/zapformer2/test_scaling.py | 1 + .../ASR/zapformer2/test_subsampling.py | 1 + egs/librispeech/ASR/zapformer2/train.py | 1678 +++++++++++++ egs/librispeech/ASR/zapformer2/zipformer.py | 2066 +++++++++++++++++ 52 files changed, 7581 insertions(+) create mode 100644 egs/librispeech/ASR/zapformer2/.gitignore create mode 100755 egs/librispeech/ASR/zapformer2/asr_datamodule.py create mode 120000 egs/librispeech/ASR/zapformer2/attention_decoder.py create mode 120000 egs/librispeech/ASR/zapformer2/beam_search.py create mode 120000 egs/librispeech/ASR/zapformer2/ctc_decode.py create mode 100755 egs/librispeech/ASR/zapformer2/decode.py create mode 120000 egs/librispeech/ASR/zapformer2/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/zapformer2/decode_stream.py create mode 120000 egs/librispeech/ASR/zapformer2/decoder.py create mode 120000 egs/librispeech/ASR/zapformer2/encoder_interface.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx.py create mode 120000 egs/librispeech/ASR/zapformer2/export.py create mode 120000 egs/librispeech/ASR/zapformer2/finetune.py create mode 120000 egs/librispeech/ASR/zapformer2/generate_averaged_model.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/joiner.py create mode 120000 egs/librispeech/ASR/zapformer2/label_smoothing.py create mode 100755 egs/librispeech/ASR/zapformer2/model.py create mode 120000 egs/librispeech/ASR/zapformer2/my_profile.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_check.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_decode.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/optim.py create mode 120000 egs/librispeech/ASR/zapformer2/pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/pretrained_ctc.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py create mode 120000 egs/librispeech/ASR/zapformer2/scaling.py create mode 120000 egs/librispeech/ASR/zapformer2/scaling_converter.py create mode 100755 egs/librispeech/ASR/zapformer2/speech_recognition.py create mode 120000 egs/librispeech/ASR/zapformer2/streaming_beam_search.py create mode 120000 egs/librispeech/ASR/zapformer2/streaming_decode.py create mode 120000 egs/librispeech/ASR/zapformer2/subsampling.py create mode 120000 egs/librispeech/ASR/zapformer2/test_scaling.py create mode 120000 egs/librispeech/ASR/zapformer2/test_subsampling.py create mode 100755 egs/librispeech/ASR/zapformer2/train.py create mode 100644 egs/librispeech/ASR/zapformer2/zipformer.py diff --git a/egs/librispeech/ASR/zapformer2/.gitignore b/egs/librispeech/ASR/zapformer2/.gitignore new file mode 100644 index 0000000000..e47ac15828 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/librispeech/ASR/zapformer2/asr_datamodule.py b/egs/librispeech/ASR/zapformer2/asr_datamodule.py new file mode 100755 index 0000000000..4db6e101fb --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/asr_datamodule.py @@ -0,0 +1,454 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +# This K2SpeechRecognitionDataset is a modified version of one from +# lhotse.dataset, modified to, in training mode, to return a batch that has 3 +# different copies of the same data with the last two having different Musan +# augmentations and the first having none; and also include the key "num_copies" +# in the batch which would be 1 for the validation data (no Musan) and 3 for the +# training data with musan. +from speech_recognition import K2SpeechRecognitionDataset +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=[], + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zapformer2/attention_decoder.py b/egs/librispeech/ASR/zapformer2/attention_decoder.py new file mode 120000 index 0000000000..830180a0cd --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/attention_decoder.py @@ -0,0 +1 @@ +../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/beam_search.py b/egs/librispeech/ASR/zapformer2/beam_search.py new file mode 120000 index 0000000000..8554e44ccf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/ctc_decode.py b/egs/librispeech/ASR/zapformer2/ctc_decode.py new file mode 120000 index 0000000000..a78e5c1df0 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/ctc_decode.py @@ -0,0 +1 @@ +../zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode.py b/egs/librispeech/ASR/zapformer2/decode.py new file mode 100755 index 0000000000..221f01297b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode.py @@ -0,0 +1,1089 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2] + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py new file mode 120000 index 0000000000..63b0ef617b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py @@ -0,0 +1 @@ +../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode_stream.py b/egs/librispeech/ASR/zapformer2/decode_stream.py new file mode 120000 index 0000000000..4e59d04a12 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode_stream.py @@ -0,0 +1 @@ +../zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decoder.py b/egs/librispeech/ASR/zapformer2/decoder.py new file mode 120000 index 0000000000..cab465d2b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decoder.py @@ -0,0 +1 @@ +../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/encoder_interface.py b/egs/librispeech/ASR/zapformer2/encoder_interface.py new file mode 120000 index 0000000000..aa5d0217a8 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py new file mode 120000 index 0000000000..dc14e93e75 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py new file mode 120000 index 0000000000..3baa2b673c --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py new file mode 120000 index 0000000000..d18cb9a9a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx.py b/egs/librispeech/ASR/zapformer2/export-onnx.py new file mode 120000 index 0000000000..f343cf7027 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx.py @@ -0,0 +1 @@ +../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export.py b/egs/librispeech/ASR/zapformer2/export.py new file mode 120000 index 0000000000..1a126ab695 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export.py @@ -0,0 +1 @@ +../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/finetune.py b/egs/librispeech/ASR/zapformer2/finetune.py new file mode 120000 index 0000000000..0e9e7989b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/finetune.py @@ -0,0 +1 @@ +../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/generate_averaged_model.py b/egs/librispeech/ASR/zapformer2/generate_averaged_model.py new file mode 120000 index 0000000000..b65513a058 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/generate_averaged_model.py @@ -0,0 +1 @@ +../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained.py b/egs/librispeech/ASR/zapformer2/jit_pretrained.py new file mode 120000 index 0000000000..5d45825206 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py new file mode 120000 index 0000000000..43aeb684bf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py new file mode 120000 index 0000000000..8e5e6f9812 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/joiner.py b/egs/librispeech/ASR/zapformer2/joiner.py new file mode 120000 index 0000000000..444cb5f150 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/joiner.py @@ -0,0 +1 @@ +../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/label_smoothing.py b/egs/librispeech/ASR/zapformer2/label_smoothing.py new file mode 120000 index 0000000000..3690afff9d --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/label_smoothing.py @@ -0,0 +1 @@ +../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/model.py b/egs/librispeech/ASR/zapformer2/model.py new file mode 100755 index 0000000000..278e498032 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/model.py @@ -0,0 +1,630 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from encoder_interface import EncoderInterface +from scaling import ScaledLinear, convert_num_channels, PredictLoss +from icefall.utils import add_sos, make_pad_mask, time_warp + + +class AsrModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + use_attention_decoder: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + use_attention_decoder: + Whether use attention-decoder head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.predict_loss = PredictLoss(encoder_dim) + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.1, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.1, + ) + + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + nn.LogSoftmax(dim=-1), + ) + + self.use_attention_decoder = use_attention_decoder + if use_attention_decoder: + self.attention_decoder = attention_decoder + else: + assert attention_decoder is None + + self.reconstruction_proj = ScaledLinear( + encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) + + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, sd_prob: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. + + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) + + x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + + src_key_padding_mask = make_pad_mask(x_lens) # (N, T) + specaug_mask = specaug_mask[:, ::2] + assert abs(specaug_mask.shape[1] - src_key_padding_mask.shape[1]) < 10 + specaug_mask = convert_num_channels(specaug_mask, src_key_padding_mask.shape[1]) # pad or truncate. (N, T) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, + aux_loss_scale=aux_loss_scale, + sd_prob=0.0) + + predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + + return encoder_out, encoder_out_lens, predict_loss + + + def compute_predict_loss(self, + encoder_out: Tensor, + src_key_padding_mask: Optional[Tensor], + specaug_mask: Optional[Tensor]) -> Tensor: + if src_key_padding_mask is not None and specaug_mask is not None: + mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) + elif src_key_padding_mask is not None: + mask = src_key_padding_mask.t().logical_not() + elif specaug_mask is not None: + mask = specaug_mask.t().logical_not() + else: + mask = None + return self.predict_loss(encoder_out, mask) + + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). This activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + return ctc_loss + + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss, with consistency regularization loss if we are in training mode. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.amp.autocast('cuda', enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.amp.autocast('cuda', enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + spec_augment: Optional[nn.Module] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + num_copies: int = 1, + aux_loss_scale: float = 0.0, + sd_prob: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + spec_augment: + The SpecAugment instance, or similar/compatible object, that masks + log-mel features. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. Used only for + time-warping, if num_copies > 1. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if num_copies > 1, corresponds to training mode. + num_copies: + the number of copies of the same data that are in the batch, e.g. 1, 2 + or 3; affects CRCTC, spec-augment, etc. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. + + Returns: + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + device = x.device + + if num_copies > 1: + assert num_copies == 3 # for now. + # will do SpecAugment or similar. + assert spec_augment is not None and getattr(spec_augment, 'time_warp_factor', -1) < 0 + + (batch_size, seq_len, num_channels) = x.shape + B = batch_size // num_copies + x = x.reshape(num_copies, B, seq_len, num_channels) + + do_time_warp = True + if do_time_warp: + # Apply time warping. First append the copies on the channel + # dimension so all copies get the exact same time-warping. + x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) + + assert supervision_segments is not None + with torch.amp.autocast('cuda', enabled=False): + x = time_warp( + x.to(torch.float), + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments[:B], + ) + x = x.reshape(B, seq_len, num_copies, num_channels) + x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + + # x_no_specaug is several repeats of the 1st copy of the data, which + # is the one not augmented with Musan. But it does have time + # warping and mel warping. + x_no_specaug = x[0:1].repeat(num_copies - 1, 1, 1, 1).reshape( + B * (num_copies - 1), seq_len, num_channels) + + + # Independently apply frequency masking and time masking to all but the first + # copy of the data. + x = spec_augment(x[1:].reshape(-1, seq_len, num_channels)) + + x_lens = x_lens[:B*(num_copies-1)] + y = y[:B*(num_copies-1)] + else: + x_no_specaug = x + + + # Compute encoder outputs + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale, + sd_prob=sd_prob) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + targets = y.values + if not self.training: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) + + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, + encoder_out_lens) + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss + + + def forward_reconstruction_loss(self, + log_mels: Tensor, + encoder_out: Tensor, + encoder_out_lens: Tensor): + """ + Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If + use_cr_ctc then we swap the first and second halves of the batch. + + Args: + log_mels: log-mel features of shape (batch_size, T, num_mels) + encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) + """ + batch_size = log_mels.shape[0] + num_mels = log_mels.shape[2] + + + def gauss_norm(x): + # normalize by gaussianizing on each dimension + values, indexes = x.sort(dim=1) # sort on seq dim + N = max(2, x.shape[1]) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[1], device=x.device, dtype=torch.float) + norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data + norm_rank = norm_rank.reshape(1, -1, 1) + norm_rank = norm_rank.repeat(x.shape[0], 1, x.shape[2]) + x_norm = torch.empty_like(x) + x_norm.scatter_(dim=1, index=indexes, src=norm_rank) + return x_norm + + log_mels = gauss_norm(log_mels) + + pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) + T_embed = pred_mels.shape[1] + pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) + + excess_frames = log_mels.shape[1] - pred_mels.shape[1] + assert 4 < excess_frames < 10 # should be around 7 or 8 I believe. + + T = pred_mels.shape[1] + offset = 3 # i found excess_frames = 5 one time. + log_mels = log_mels[:, offset:offset+T] + + lens = encoder_out_lens * 4 + pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions + assert pad_mask.shape == (batch_size, T) + pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position + # padd_mask: (batch_size, T, 1) + + + # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly + # helps to down-weight the effect of very silent silences. + #loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, + # reduction='none', beta=1.0) + # this way of applying the padding mask is not really ideal in terms of normalization, + # it will cause us to under-normalize a bit. + diff = log_mels * pad_mask - pred_mels * pad_mask + + loss = (diff ** 2) + + # removing the masking logic since we now use the no-specaug reference sequence. + ## masking. if it's different from the next item on both the frequency dim + ## and the time dim, it means we are in neither a time masked nor a frequency masked + ## position. + #mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), + # log_mels != torch.roll(log_mels, 1, dims=1)) + #loss = loss * mask.to(loss.dtype) + + loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. + return loss diff --git a/egs/librispeech/ASR/zapformer2/my_profile.py b/egs/librispeech/ASR/zapformer2/my_profile.py new file mode 120000 index 0000000000..76e48b756b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/my_profile.py @@ -0,0 +1 @@ +../zipformer/my_profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_check.py b/egs/librispeech/ASR/zapformer2/onnx_check.py new file mode 120000 index 0000000000..7293c70d46 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_check.py @@ -0,0 +1 @@ +../zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_decode.py b/egs/librispeech/ASR/zapformer2/onnx_decode.py new file mode 120000 index 0000000000..9e3faa5e01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_decode.py @@ -0,0 +1 @@ +../zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py new file mode 120000 index 0000000000..f8abb9daa5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py new file mode 120000 index 0000000000..11b846322e --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained.py new file mode 120000 index 0000000000..a085def837 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py new file mode 120000 index 0000000000..0c082a204f --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py new file mode 120000 index 0000000000..68102c7374 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py new file mode 120000 index 0000000000..8314b4efdf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py new file mode 120000 index 0000000000..7a637a1c01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py new file mode 120000 index 0000000000..a5b04b3f8b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/optim.py b/egs/librispeech/ASR/zapformer2/optim.py new file mode 120000 index 0000000000..207eecfcda --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained.py b/egs/librispeech/ASR/zapformer2/pretrained.py new file mode 120000 index 0000000000..70ad71ffc6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/pretrained.py @@ -0,0 +1 @@ +../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/pretrained_ctc.py new file mode 120000 index 0000000000..fb9bdf1fa2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py new file mode 100755 index 0000000000..aa85d1fff7 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_k_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in scores_grad + BLOCK_N: tl.constexpr, # block size in q + BLOCK_C: tl.constexpr, # block size for seq_q + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_k, BLOCK_M) + num_pid_n = tl.cdiv(channels, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(BLOCK_M == 1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), for k, seq_k + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,), for j, channel + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n_mask = offs_n[:, None] < channels + + # (BLOCK_C,), for i, seq_q + offs_c = tl.arange(0, BLOCK_C) + + q_base = q_ptr + batch * stride_qb + head * stride_qh + offs_n[:, None] * stride_qc + k_base = k_ptr + batch * stride_kb + head * stride_kh + pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc + scores_grad_base = ( + scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sk + ) + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + scores_grad_mask = (offs_m < seq_k) & (c_idx[None, :] < seq_q) + + # (BLOCK_N, BLOCK_C), or (J, I) + q_mask = offs_n_mask & (c_idx[None, :] < seq_q) + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + rel_idx = c_idx[None, :] - offs_m + max_seq_len - 1 + + # (BLOCK_M, BLOCK_N, BLOCK_C), or (K, J, I), or (BLOCK_N, BLOCK_C) + pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask + + scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sq + q_ptrs = q_base + c_idx[None, :] * stride_qs + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + # (BLOCK_N, BLOCK_C), or (J, I) + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C), or (J, I) + pos_ptrs = pos_base + rel_idx * stride_ps + + pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) + + # scores_grad_chunk (1, BLOCK_C), or (K, I) + # q_chunk (BLOCK_N, BLOCK_C), or (J, I) + # pos_chunk (BLOCK_N, BLOCK_C), or (J, I) + qp = q_chunk * pos_chunk + + acc += tl.sum(scores_grad_chunk * qp, axis=1) + + k_ptrs = k_base + offs_m * stride_ks + offs_n * stride_kc + k_mask = (offs_m < seq_k) & (offs_n < channels) + tl.store(k_ptrs, acc, mask=k_mask) + + +def relative_position_attention_bwd_k(scores_grad, q, pos): + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + assert q.is_contiguous() + assert pos.is_contiguous() + + assert scores_grad.ndim == q.ndim == 4, (scores_grad.shape, q.shape) + + assert pos.ndim == 3, pos.shape + b, h, seq_q, seq_k = scores_grad.shape + + assert q.shape[0] == b, q.shape + assert q.shape[1] == h, q.shape + assert q.shape[2] == seq_q, q.shape + + c = q.shape[3] + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert scores_grad.device == q.device == pos.device, ( + scores_grad.device, + q.device, + pos.device, + ) + + k = torch.empty(b, h, seq_k, c, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_k, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_k_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return k + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_k(scores_grad, q, pos), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + pos_copy = pos.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("k.grad", k.grad.shape, k.grad.sum()) + + scores_grad = scores0.grad.clone() + k_grad = relative_position_attention_bwd_k(scores_grad, q_copy, pos_copy) + + print(k_grad.shape, k_grad.sum()) + print((k.grad - k_grad).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py new file mode 100755 index 0000000000..93d1f09dc3 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 16, + "BLOCK_C": 16, + "GROUP_SIZE_M": 4, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_pos_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in q + BLOCK_N: tl.constexpr, # block size in k + BLOCK_C: tl.constexpr, # block size for channel + GROUP_SIZE_M: tl.constexpr, # size for grouped block, not used +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_n = tl.cdiv(seq_k, BLOCK_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(BLOCK_M == 1) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,) + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, 1) + rel_idx = offs_m - offs_n[:, None] + max_seq_len - 1 + + q_base = q_ptr + batch * stride_qb + head * stride_qh + k_base = k_ptr + batch * stride_kb + head * stride_kh + pos_base = pos_ptr + head * stride_ph + + scores_grad_base = scores_grad_ptr + batch * stride_sb + head * stride_sh + scores_grad_ptrs = ( + scores_grad_base + offs_m * stride_sq + offs_n[:, None] * stride_sk + ) + + # (BLOCK_N, 1) + scores_grad_mask = (offs_m < seq_q) & (offs_n[:, None] < seq_k) + + # (BLOCK_N, 1) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (1, BLOCK_C) + q_mask = (offs_m < seq_q) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C), or (K, J) + k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + pos_mask = ( + (rel_idx >= 0) + & (rel_idx < 2 * max_seq_len - 1) + & (c_idx[None, :] < channels) + ) + + q_ptrs = q_base + offs_m * stride_qs + c_idx[None, :] * stride_qc + k_ptrs = k_base + offs_n[:, None] * stride_ks + c_idx[None, :] * stride_kc + + # (1, BLOCK_C) + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs = pos_base + rel_idx * stride_ps + c_idx[None, :] * stride_pc + + # q_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # scores_grad_chunk (BLOCK_N, 1) + # + # pos_chunk: (BLOCK_N, BLOCK_C) + qk = q_chunk * k_chunk + pos_chunk = scores_grad_chunk * qk + + tl.atomic_add(pos_ptrs, pos_chunk, mask=pos_mask) + + +def relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len): + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + + assert q.is_contiguous() + assert k.is_contiguous() + + assert scores_grad.ndim == q.ndim == k.ndim == 4, ( + scores_grad.shape, + q.shape, + k.shape, + ) + b, h, seq_q, seq_k = scores_grad.shape + c = q.shape[3] + + assert k.shape[0] == b, k.shape + assert k.shape[1] == h, k.shape + assert k.shape[2] == seq_k, k.shape + assert k.shape[3] == c, k.shape + + assert q.shape[0] == b, q.shape + assert q.shape[1] == h, q.shape + assert q.shape[2] == seq_q, q.shape + + assert scores_grad.device == q.device == k.device, ( + scores_grad.device, + q.device, + k.device, + ) + + pos = torch.zeros(h, 2 * max_seq_len - 1, c, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_pos_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return pos + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + k_copy = k.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("pos.grad", pos.grad.shape, pos.grad.sum()) + + pos_grad = relative_position_attention_bwd_pos( + scores0.grad, q_copy, k_copy, max_seq_len + ) + + print(pos_grad.shape, pos_grad.sum()) + print((pos.grad - pos_grad).abs().max()) + + +def main(): + # test_benchmark() + test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py new file mode 100755 index 0000000000..5a9ececf0c --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_q_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in scores_grad + BLOCK_N: tl.constexpr, # block size in channels + BLOCK_C: tl.constexpr, # block size for seq_k + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_q, BLOCK_M) + num_pid_n = tl.cdiv(channels, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(BLOCK_M == 1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), we should always set BLOCK_M to 1 + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,) for channels + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,), for seq_k + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, 1) + offs_n_mask = offs_n[:, None] < channels + + q_base = q_ptr + batch * stride_qb + head * stride_qh + k_base = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_kc + pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc + scores_grad_base = ( + scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sq + ) + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, seq_k, BLOCK_C): + c_idx = c + offs_c + + # (1, BLOCK_C) + rel_idx = offs_m - c_idx[None, :] + max_seq_len - 1 + + # (1, BLOCK_C) + scores_grad_mask = (offs_m < seq_q) & (c_idx[None, :] < seq_k) + + # (BLOCK_N, BLOCK_C) + k_mask = offs_n_mask & (c_idx[None, :] < seq_k) + + # (BLOCK_N, BLOCK_C) + pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask + + scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sk + k_ptrs = k_base + c_idx[None, :] * stride_ks + + # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs = pos_base + rel_idx * stride_ps + + pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) + + # scores_grad_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # pos_chunk (BLOCK_N, BLOCK_C) + + # kp: (BLOCK_N, BLOCK_C) + kp = k_chunk * pos_chunk + + acc += tl.sum(scores_grad_chunk * kp, axis=1) + + q_ptrs = q_base + offs_m * stride_qs + offs_n * stride_qc + q_mask = (offs_m < seq_q) & (offs_n < channels) + tl.store(q_ptrs, acc, mask=q_mask) + + +def relative_position_attention_bwd_q(scores_grad, k, pos): + """ + Args: + scores_grad: (b, h, seq_q, seq_k) + k: (b, h, seq_k, channels) + pos: (h, 2*max_seq_len-1, channels) + Returns: + grad of q: (b, h, seq_q, channels) + """ + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + assert k.is_contiguous() + assert pos.is_contiguous() + + assert scores_grad.ndim == k.ndim == 4, (scores_grad.shape, k.shape) + assert pos.ndim == 3, pos.shape + b, h, seq_q, seq_k = scores_grad.shape + + c = k.shape[3] + + assert k.shape[0] == b, (k.shape, scores_grad.shape) + assert k.shape[1] == h, (k.shape, scores_grad.shape) + assert k.shape[2] == seq_k, (k.shape, scores_grad.shape) + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert scores_grad.device == k.device == pos.device, ( + scores_grad.device, + k.device, + pos.device, + ) + + q = torch.empty(b, h, seq_q, c, device=k.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_q_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return q + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + k = torch.randn(b, h, seq_k, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_q(scores_grad, k, pos), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + k_copy = k.clone() + pos_copy = pos.clone() + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("q.grad", q.grad.shape, q.grad.sum()) + + scores_grad = scores0.grad.clone() + q_grad = relative_position_attention_bwd_q(scores_grad, k_copy, pos_copy) + print(q_grad.shape, q_grad.sum()) + print((q.grad - q_grad).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py new file mode 100755 index 0000000000..e6ea552035 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_fwd_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in q + BLOCK_N: tl.constexpr, # block size in k + BLOCK_C: tl.constexpr, # block size for channel + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_q, BLOCK_M) + num_pid_n = tl.cdiv(seq_k, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), we should always set BLOCK_M to 1 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # (BLOCK_N,) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,) + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, ) + rel_idx = offs_m - offs_n + max_seq_len - 1 + + # (BLOCK_N, 1) + rel_idx_mask = (rel_idx[:, None] >= 0) & (rel_idx[:, None] < 2 * max_seq_len - 1) + + q_ptrs = q_ptr + batch * stride_qb + head * stride_qh + offs_m[:, None] * stride_qs + k_ptrs = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_ks + + pos_ptrs = pos_ptr + head * stride_ph + rel_idx[:, None] * stride_ps + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (BLOCK_M, BLOCK_C) + q_mask = (offs_m[:, None] < seq_q) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + pos_mask = rel_idx_mask & (c_idx[None, :] < channels) + + q_ptrs0 = q_ptrs + c_idx[None, :] * stride_qc + k_ptrs0 = k_ptrs + c_idx[None, :] * stride_kc + + # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) + q_chunk = tl.load(q_ptrs0, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs0, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs0 = pos_ptrs + c_idx[None, :] * stride_pc + + pos_chunk = tl.load(pos_ptrs0, mask=pos_mask, other=0.0) + + # q_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # pos_chunk (BLOCK_N, BLOCK_C) + + acc += tl.sum(q_chunk * (k_chunk * pos_chunk), axis=1) + + scores_ptrs = ( + scores_ptr + + batch * stride_sb + + head * stride_sh + + offs_m * stride_sq + + offs_n * stride_sk + ) + scores_mask = (offs_m < seq_q) & (offs_n < seq_k) + + tl.store(scores_ptrs, acc, mask=scores_mask) + + +def relative_position_attention_fwd(q, k, pos): + assert q.is_contiguous() + assert k.is_contiguous() + assert pos.is_contiguous() + + assert q.ndim == k.ndim == 4, (q.shape, k.shape) + assert pos.ndim == 3, pos.shape + b, h, seq_q, c = q.shape + assert k.shape[0] == b, k.shape + assert k.shape[1] == h, k.shape + assert k.shape[3] == c, k.shape + + seq_k = k.shape[2] + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert q.device == k.device == pos.device, ( + q.device, + k.device, + pos.device, + ) + + scores = torch.empty(b, h, seq_q, seq_k, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_fwd_kernel[grid]( + q, k, pos, scores, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3), + ) + # fmt: on + return scores + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton", "torch"], + line_names=["Triton", "Torch"], + styles=[("green", "-"), ("blue", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance with pos", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_fwd_torch(q, k, pos), + quantiles=quantiles, + ) + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_fwd(q, k, pos), quantiles=quantiles + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 8 + seq_q = 400 + seq_k = 400 + c = 1024 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores1 = relative_position_attention_fwd(q, k, pos) + print(scores0.shape, scores0.sum()) + print(scores1.shape, scores1.sum()) + print((scores0 - scores1).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py new file mode 100755 index 0000000000..21640764ba --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +import torch + +from relative_position_attention_fwd_2 import ( + relative_position_attention_fwd, + relative_position_attention_fwd_torch, +) + +from relative_position_attention_bwd_q_2 import relative_position_attention_bwd_q +from relative_position_attention_bwd_k_2 import relative_position_attention_bwd_k +from relative_position_attention_bwd_pos_2 import relative_position_attention_bwd_pos + + +class RelativePositionAttentionFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, pos): + """ + Args: + q: (batch, head, seq_q, channel) + k: (batch, head, seq_k, channel) + pos: (head, 2*max_seq_len-1, channel) + Returns: + scores: (batch, head, seq_q, seq_k) + """ + ctx.save_for_backward(q, k, pos) + return relative_position_attention_fwd(q, k, pos) + + @staticmethod + def backward(ctx, scores_grad): + q, k, pos = ctx.saved_tensors + q_grad = None + k_grad = None + pos_grad = None + + if ctx.needs_input_grad[0]: + q_grad = relative_position_attention_bwd_q(scores_grad, k, pos) + + if ctx.needs_input_grad[1]: + k_grad = relative_position_attention_bwd_k(scores_grad, q, pos) + + if ctx.needs_input_grad[2]: + max_seq_len = (pos.shape[1] + 1) // 2 + pos_grad = relative_position_attention_bwd_pos( + scores_grad, q, k, max_seq_len + ) + + return q_grad, k_grad, pos_grad + + +class RelativePositionAttentionModule(torch.nn.Module): + def forward( + self, q: torch.Tensor, k: torch.Tensor, pos: torch.Tensor + ) -> torch.Tensor: + """ + Args: + q: (batch, head, seq_q, channel) + k: (batch, head, seq_k, channel) + pos: (head, 2*max_seq_len-1, channel) + Returns: + scores: (batch, head, seq_q, seq_k) + """ + return RelativePositionAttentionFunction.apply(q, k, pos) + + +def _test(): + torch.manual_seed(20250820) + device = torch.device("cuda", 0) + b = 4 + h = 2 + seq_q = 100 + seq_k = 100 + c = 300 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + k_copy = k.clone() + pos_copy = pos.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + + q_copy.requires_grad_(True) + k_copy.requires_grad_(True) + pos_copy.requires_grad_(True) + + scores1 = RelativePositionAttentionModule()(q_copy, k_copy, pos_copy) + + s1 = (scale * scores1).sum() + s1.backward() + + print((s0 - s1).max().abs()) + print((q.grad - q_copy.grad).max().abs()) + print((k.grad - k_copy.grad).max().abs()) + print((pos.grad - pos_copy.grad).max().abs()) + """ + tensor(0.0005, device='cuda:0', grad_fn=) + tensor(7.6294e-06, device='cuda:0') + tensor(5.7220e-06, device='cuda:0') + tensor(3.4332e-05, device='cuda:0') + """ + + +if __name__ == "__main__": + _test() + pass diff --git a/egs/librispeech/ASR/zapformer2/scaling.py b/egs/librispeech/ASR/zapformer2/scaling.py new file mode 120000 index 0000000000..58e4b0a0fe --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/scaling.py @@ -0,0 +1 @@ +../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/scaling_converter.py b/egs/librispeech/ASR/zapformer2/scaling_converter.py new file mode 120000 index 0000000000..bc7c7b5e37 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/scaling_converter.py @@ -0,0 +1 @@ +../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/speech_recognition.py b/egs/librispeech/ASR/zapformer2/speech_recognition.py new file mode 100755 index 0000000000..dd069cf3da --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/speech_recognition.py @@ -0,0 +1,229 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class K2SpeechRecognitionDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech recognition task using k2 library. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + if self.cut_transforms: + orig_cuts = cuts + + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts + num_copies = 3 + else: + num_copies = 1 + + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) diff --git a/egs/librispeech/ASR/zapformer2/streaming_beam_search.py b/egs/librispeech/ASR/zapformer2/streaming_beam_search.py new file mode 120000 index 0000000000..97e6e733f2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/streaming_beam_search.py @@ -0,0 +1 @@ +../zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/streaming_decode.py b/egs/librispeech/ASR/zapformer2/streaming_decode.py new file mode 120000 index 0000000000..e31da07d01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/streaming_decode.py @@ -0,0 +1 @@ +../zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/subsampling.py b/egs/librispeech/ASR/zapformer2/subsampling.py new file mode 120000 index 0000000000..d178adc2e5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/subsampling.py @@ -0,0 +1 @@ +../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_scaling.py b/egs/librispeech/ASR/zapformer2/test_scaling.py new file mode 120000 index 0000000000..b776da79a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/test_scaling.py @@ -0,0 +1 @@ +../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_subsampling.py b/egs/librispeech/ASR/zapformer2/test_subsampling.py new file mode 120000 index 0000000000..2925ea3c51 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/test_subsampling.py @@ -0,0 +1 @@ +../zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/train.py b/egs/librispeech/ASR/zapformer2/train.py new file mode 100755 index 0000000000..4294e139f6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/train.py @@ -0,0 +1,1678 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Sched3, TransformedAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def get_adjusted_lr_batches(params: AttributeDict) -> float: + # returns an adjusted form of the "lr_batches" parameter used to set the learning + # rate in the Sched3 scheduler. + # We want the final LR to be based on the geometric mean of "how much data we + # have seen" and "how many batches we have seen". + # an easier way to look at it is this: the formula for learning rate depends + # on (cur_batch / lr_batches). if we write this as: + # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches + # then the numerator is a geometric mean of "how many batches we have seen" + # and "how much data we have seen". We can achieve this by setting + # lr_batches = params.lr_batches * (duration_ratio ** -0.5). + duration_ratio = (params.max_duration * params.world_size) / params.ref_duration + lr_batches = params.lr_batches * (duration_ratio ** -0.5) + logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") + return lr_batches + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def lookup(params: AttributeDict, name: str): + """ + Interprets numerical arguments in `params` by taking into account base-dim; + also parses comma-separated lists of integers, turning them into tuples. + If a particular attribute ending in "dim" is not present we look up + the same name but ending in "factor", and multiply the elements by base_dim. + """ + try: + attr = getattr(params, name) + try: + attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints + if len(attr) == 1: + attr = attr[0] + except: + pass # leave attr as it is, e.g. a string. + return attr + except AttributeError as e: + if name[-3:] != "dim": + raise e + try: + attr = getattr(params, name[:-3] + "multiple") + if isinstance(attr, str): + attr = tuple(map(int, attr.split(","))) # tuple of ints + base_dim = params.base_dim + attr = tuple([i * base_dim for i in attr]) + if len(attr) == 1: + attr = attr[0] + else: # assume int. + assert isinstance(attr, (int, float)), (name, attr) + attr = attr * params.base_dim + return attr + except AttributeError as e: + raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") + + + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="3,5,6,6,6,5", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--base-dim", + type=int, + default=64, + help="Dimension that, via multiples, defines the dimensions of the model." + ) + + parser.add_argument( + "--embed-multiple", + type=int, + default=6, + help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", + ) + + parser.add_argument( + "--feedforward-multiple", + type=str, + default="3,3,3,3,3,3", + help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-multiple", + type=str, + default="4,6,9,12,9,6", + help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-multiple", + type=int, + default=8, + help="Factor by which embedding dimension in the decoder model is larger than base-dim.", + ) + + parser.add_argument( + "--joiner-multiple", + type=int, + default=8, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--attention-decoder-multiple", + type=int, + default=8, + help="""Factor by which attention decoder dim is larger than base-dim""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-multiple", + type=int, + default=8, + help="""Determines attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-multiple", + type=int, + default=4, + help="""Factor by which feedforward hidden dim in attention decoder is larger than attention-decoder-dim""" + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=True, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--debug-interval", + type=int, + default=10, + help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for + finding parts of the network that are diverging or not well trained. + """ + ) + + parser.add_argument( + "--dump-debug-interval", + type=int, + default=0, + help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they + are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. + Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file + opened, seeked-in and closed for each scalar that is written). + """ + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=17500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--aux-loss-scale", + type=float, + default=0.05, + help="Scale on auxiliary losses that are defined in the model, such " + "as cosine loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--reconstruction-loss-scale", + type=float, + default=0.005, + help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", + ) + + parser.add_argument( + "--predict-loss-scale", + type=float, + default=0.01, + help="Prediction of random k-means after widest zipformer layer" + ) + + parser.add_argument( + "--stochastic-depth-prob", + type=float, + default=0.1, + help="Probability of using a randomly chosen stack output during training, instead of " + "final output." + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on pruned loss (for transducer) and the reconstruction and prediction + losses. Expressed in terms of the "adjusted batch count", i.e. the + normalized batch count after adjusting for changes in batch size. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=lookup(params, "embed_dim"), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + input_dim=lookup(params, "embed_dim"), + output_downsampling_factor=2, + downsampling_factor=lookup(params, "downsampling_factor"), + num_encoder_layers=lookup(params, "num_encoder_layers"), + encoder_dim=lookup(params, "encoder_dim"), + query_head_dim=lookup(params, "query_head_dim"), + value_head_dim=lookup(params, "value_head_dim"), + num_heads=lookup(params, "num_heads"), + feedforward_multiple=lookup(params, "feedforward_multiple"), + cnn_module_kernel=lookup(params, "cnn_module_kernel"), + dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero + causal=params.causal, + chunk_size=lookup(params, "chunk_size"), + left_context_frames=lookup(params, "left_context_frames"), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "decoder_dim"), + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 + joiner = Joiner( + encoder_dim=lookup(params, "embed_dim") * output_downsampling_factor, + decoder_dim=lookup(params, "decoder_dim"), + joiner_dim=lookup(params, "joiner_dim"), + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "attention_decoder_dim"), + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=lookup(params, "attention_decoder_attention_dim"), + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_multiple * lookup(params, "attention_decoder_attention_dim"), + memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + output_downsampling_factor = 2 + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=output_downsampling_factor * lookup(params, "embed_dim"), + decoder_dim=lookup(params, "decoder_dim"), + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + spec_augment: Optional[nn.Module] = None, + aux_loss_scale: float = 0.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + spec_augment: + The nn.Module instance (or similar object), used for training + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + num_copies = batch["num_copies"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + if num_copies > 1: + assert model.training + # will need the following for time-warping in nn.Module. + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + spec_augment = None # disable spec-aug + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=80, # for specaug + num_copies=num_copies, + aux_loss_scale=aux_loss_scale, + sd_prob=(params.stochastic_depth_prob if is_training else 0.0), + ) + + loss = 0.0 + + adjusted_batch_count = params.batch_idx_train + warm_step = params.warm_step + def warmup_schedule(scale, initial_factor): + # geometric warmup schedules. + warmup_factor = (1. if adjusted_batch_count >= warm_step else + initial_factor + (adjusted_batch_count / warm_step) * (1 - initial_factor)) + return scale * warmup_factor + + if params.use_transducer: + simple_loss_scale = params.simple_loss_scale + pruned_loss_scale = warmup_schedule(1.0, 0.05) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if num_copies > 1: + loss += params.cr_loss_scale * cr_loss + + reconstruction_loss_scale = params.reconstruction_loss_scale + + loss += reconstruction_loss_scale * reconstruction_loss + + if num_copies > 1: + loss += params.predict_loss_scale * predict_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + nframes = (feature_lens // params.subsampling_factor).sum().item() + if num_copies > 1: + nframes = nframes * (num_copies - 1) / num_copies # omit 1st copy + info["frames"] = nframes + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if num_copies > 1: + info["cr_loss"] = cr_loss.detach().cpu().item() + if num_copies > 1: + info["predict_loss"] = predict_loss.detach().cpu().item() + info["recon_loss"] = reconstruction_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + spec_augment: Optional[nn.Module] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + spec_augment: + The SpecAugment or similar instance used for CR-CTC. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def get_scaler_scale(): + if params.use_autocast and scaler._scale is not None: + return scaler._scale.item() + else: + return 1.0 + + def save_bad_model(suffix: str = ""): + if params.debug_interval > 0: + optimizer.write_debug_info(summary_writer=tb_writer) + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = get_scaler_scale() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or + batch_idx % 100 == 0 and cur_grad_scale < 8.0 or + batch_idx % 400 == 0 and cur_grad_scale < 32.0): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = get_scaler_scale() + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + optimizer.write_debug_info(summary_writer=tb_writer) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + + assert params.use_ctc # for now, require CTC, we may remove this requirement later. + + spec_augment = ExpAugment() + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = TransformedAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + debug_interval=params.debug_interval, + ) + + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and False: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + spec_augment=spec_augment, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + spec_augment=spec_augment, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + d = diagnostic.print_diagnostics() + filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" + torch.save(d, filename) + logging.info(f"Saved detailed diagnostics to {filename}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + spec_augment: Optional[nn.Module] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer2/zipformer.py b/egs/librispeech/ASR/zapformer2/zipformer.py new file mode 100644 index 0000000000..f5e1afe779 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/zipformer.py @@ -0,0 +1,2066 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union +from relative_position_attention_module_optimized import RelativePositionAttentionFunction +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + OrthogonalLinear, + SimpleOrthogonalLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + ScaleLimiter, + ActivationDropoutAndLinear, + ExpNorm, + ChunkCausalDepthwiseConv1d, + CosineSimilarityLoss, + MinProductLoss, + MaxProductLoss, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, + with_loss, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + + dropout (float): dropout rate + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + def __init__( + self, + input_dim: int, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + query_head_dim: Union[int, Tuple[int]] = 24, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_multiple: Union[int, Tuple[int]] = 4, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + dropout: FloatLike = None, # see code below for default + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_multiple = _to_tuple(feedforward_multiple) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + # each one will be Zipformer2Encoder or OrthogonalDownsample or OrthogonalUpsample + encoders = [] + + num_encoders = len(downsampling_factor) + + # caution: some changes we made for this break the streaming, later we'll try to fix this. + encoders_downsampling_factors = [ ] + + # make it so large the limit is never reached. + max_proj_dim = max(downsampling_factor) * max(encoder_dim) + + + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_multiple=feedforward_multiple[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + num_conv_modules=(2 if downsampling_factor[i] == 1 else 1), + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + head_dim=query_head_dim[i], + dim=downsampling_factor[i]*input_dim, + out_proj=False, # (downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + sd_prob: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + sd_prob: + Stochastic-depth prob: with this probability we replace the final output + with the output of a randomly chosen stack (including the 'zero stack' which + means the original input x). Each stack except the 'zero stack' has a + separate output projection for stochastic depth, that only sees the + "non-bypass part", i.e. its encoder stack without the residual. + Returns: + Return (embeddings_lengths), where: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + chunk_size, left_context_chunks = self.get_chunk_info() + orig_seq_len = x.shape[0] + + pad = (-orig_seq_len) % max(self.downsampling_factor) + # pad sequence length to be multiple of max(self.downsampling_factor) + x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), + dim=0) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) + + num_stacks = len(self.downsampling_factor) + + x_sd = x + + def randomly_choose_seqs(x, this_x, prob: float): + batch_size = x.shape[1] + do_replace = (torch.rand(1, batch_size, 1, device=x.device) < prob).expand_as(x) + return torch.where(do_replace, this_x, x) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = downsample_by(x, ds) + T = x.shape[0] + x, this_x_sd = module( + x, + chunk_size=chunk_size, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=(None + if attn_mask is None + else attn_mask[::ds, ::ds] + ), + aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) + ) + x = upsample_by(x, ds) + if sd_prob: + x_sd = randomly_choose_seqs(x_sd, upsample_by(this_x_sd, ds), 1. / (2. + i)) + + + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + od = self.output_downsampling_factor + x = downsample_by(x, od) + x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + if sd_prob: + x_sd = downsample_by(x_sd, od) + x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + x = randomly_choose_seqs(x, x_sd, sd_prob) + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + new_states += new_layer_states + + x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. + + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def get_max_similarity(rank: int, power: float): + """ + This returns a value for the "max_similarity" argument of CosineSimilarityLoss. + the max_similarity is an upper limit we impose on the mean value of (x_i . x_j) + if i != j are two different sequence-position indexes and x_i and x_j are + activation vectors normalized to have unit length. + + rank: the dimension of the space, usually this is the num_channels, but if + we have just up-projected from a bottleneck, it would be the bottleneck + dimension. + power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean + we enforce the vector dimensions to be completely independent like Gaussian noise + (don't do this); if we set power=0.0 it would be equivalent to not having + the CosineSimilarityLoss at all. + + The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal + variable. If x consists of independent Gaussian noise of dimension D, with + variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" + would be close to a no-op for large D), then (x_i . x_j) would be distributed as + a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) + would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value + between this and 1, as a kind of heuristic limit on this max_similarity. + """ + return (0.7978845608 / (rank ** 0.5)) ** power + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +def pad_mask(mask: Optional[Tensor], seq_len: int): + # mask: (batch_size, old_seq_len) + # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). + if mask is None: + return None + (batch_size, old_seq_len) = mask.shape + pad = seq_len - old_seq_len + if pad == 0: + return mask + else: + return torch.cat((mask, torch.ones(batch_size, pad, device=mask.device, dtype=torch.bool)), + dim=1) + + +def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len // downsampling_factor, batch_size, num_channels * downsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len // downsampling_factor, downsampling_factor, batch_size, num_channels) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len // downsampling_factor, batch_size, downsampling_factor * num_channels) + return x + +def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len, batch_size, upsampling_factor, num_channels // upsampling_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + return x + + +def get_dct_matrix(N): + """ + Generates an orthonormal DCT-II matrix for a given size N. + Args: + N (int): The size of the square matrix. + Returns: + torch.Tensor: The N x N orthonormal DCT-II matrix. + """ + # Create the base matrix with dimensions (N, N) + mat = torch.zeros(N, N) + # Create a tensor for the indices k (rows) and n (columns) + k = torch.arange(N).unsqueeze(1) + n = torch.arange(N).unsqueeze(0) + # Fill the matrix using the DCT-II formula + mat = math.sqrt(2 / N) * torch.cos(math.pi / (2 * N) * (2 * n + 1) * k) + # Adjust the first row (k=0) with a special normalization factor + mat[0] *= (2 ** -0.5) + return mat + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_multiple: determines the hidden dimension of the feedforward module + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + value_head_dim: int, + feedforward_multiple: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + num_conv_modules: int = 2, + causal: bool = False, + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + self.name = None # will be set from training loop + + self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + num_heads=2 * num_heads, + query_head_dim=query_head_dim, + dropout=0.0, + ) + + self.self_attn1, self.self_attn2, self.self_attn3 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(3) ] + + feedforward_dim = embed_dim * feedforward_multiple + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) + + if num_conv_modules >= 2: + self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + if num_conv_modules >= 1: + self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + + self.scale_limiter = ScaleLimiter(max_var=2.0) + + self.norm = ExpNorm(embed_dim) + + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, head_dim) or (batch_size, 2*seq_len-1, head_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + aux_loss_scale=0.1 * aux_loss_scale, + ) + num_heads = attn_weights.shape[0] // 2 # num heads per self_attn module + attn_weights1 = attn_weights[:num_heads] + attn_weights2 = attn_weights[num_heads//2:-num_heads//2] + attn_weights3 = attn_weights[num_heads:] + + src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + if hasattr(self, 'conv_module1'): + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.self_attn3(src, attn_weights3, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + if hasattr(self, 'conv_module2'): + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) + offset = (src - src_orig) * residual_scale + src = src_orig + offset + + src = with_loss(src, + self.cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), + None) + + src = self.scale_limiter(src) + + src = self.norm(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.residual(src_orig, src) + + src = self.norm(src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + dim: the dimension of the input and output (layer dim may be less than this). + pos_dim: the dimension for the relative positional encoding +dropout: + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + + + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dim: int, + head_dim: int, + out_proj: bool, + ) -> None: + super().__init__() + + # self.downsample will also reverse the downsampling operation for us afterward. + self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) + self.proj.lr_scale = 0.75 + + self.encoder_pos = CompactRelPositionalEncoding( + head_dim, dropout_rate=0.0, length_factor=1.0 + ) + self.name = None + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.residual_scales = nn.Parameter( + torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), + (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], + dim=0)) + + self.copy_bypass = Identity() + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) + + # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear + # module. + if out_proj: + self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) + self.out_proj.lr_scale = 0.75 + + # stochastic-depth proj. + self.sd_proj = nn.Linear(encoder_layer.embed_dim, dim) + + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), + but embed_dim is allowed to exceed the modules' embed_dim; we will bypass + any extra dimensions. + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + (out, out_sd), both of the same shape as src, + where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. + """ + pos_emb = self.encoder_pos(src) + + src_orig_fulldim = src + + src = self.proj(src) # project to layer dim. + + num_layers = len(self.layers) + src_orig = src + + residual_scale = limit_param_value(self.residual_scales[0], + min=-1.0, max=-0.5) + src_with_bypass = residual_scale * src + + for i, mod in enumerate(self.layers): + src = mod( + src, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale/num_layers, + ) + residual_scale = limit_param_value(self.residual_scales[i + 1], + min=0.0 if i + 1 < num_layers else 0.1, + max=1.0) + src_with_bypass = src_with_bypass + residual_scale * src + + + offset = src_with_bypass + + src = src_orig_fulldim + self.proj(offset, transpose=True) + # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, + # because of some identities involving orthogonal matrices. + + if aux_loss_scale: + src = with_loss(src, + self.offset_cosine_loss(offset.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask) + + self.cosine_loss(src.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask), + None) + + src_sd = self.sd_proj(offset) + + if hasattr(self, 'out_proj'): + src = self.out_proj(src) + + return src, src_sd + + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + num_channels = src.shape[-1] + layer_dim = self.layers[0].embed_dim + if num_channels > layer_dim: + src, bypass = src[..., :layer_dim], src[..., layer_dim:] + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + src, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + src, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + if num_channels > layer_dim: + src = torch.cat((src, bypass), dim=-1) + + return src, new_states + + +class ResidualModule(nn.Module): + """ + An nn.Module that implements a learnable residual scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + function_scale_min: FloatLike = 0.1, + ): + super().__init__() + self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.function_scale_min = copy.deepcopy(function_scale_min) + + + def _get_scales(self): + function_scale = self.function_scale + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + function_scale = limit_param_value( + function_scale, min=float(self.function_scale_min), max=1.0, + ) + residual_scale = 1.0 - function_scale + return residual_scale, function_scale + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + residual_scale, function_scale = self._get_scales() + return residual_scale * src_orig + function_scale * src + + +class OrthogonalDownsample(torch.nn.Module): + """ + Downsamples on sequence axis by appending sequence-positions together, + and then optionally projects by an orthogonal matrix + + + +. Projection is initialized + in a special way and enforced to be orthogonal. + + Args: + channels: the number of input channels; the num output channels will be twice this + proj_dim: the number of channels, after combining 2 frames by interpolating their channels + as [ a b a b, .. ] that will actually be projected; the rest are just copied. + proj_dim=2 * channels would mean all channels are projected in a learned way + causal: True for causal systems, only affects error messages as requires even + input num frames. + """ + def __init__( + self, channels: int, proj_dim: int, causal: bool = False, + ): + super().__init__() + assert proj_dim <= channels * 2 + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + self.causal = causal + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + + if seq_len % 2 == 1: + if torch.jit.is_tracing(): + assert ( + not self.causal + ), f"pad should be zero for exporting streaming models. Given {pad}" + src = torch.cat((src, src[-1:]), dim=0) + seq_len += 1 + + # the following will place each 2 frames of a particular channel right after + # each other as if they were two different channels. + src = torch.stack((src[0::2], src[1::2]), dim=-1) + src = src.reshape(seq_len // 2, batch_size, in_channels * 2) + proj_channels = self.proj.weight.shape[0] + if proj_channels < in_channels * 2: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + return src + +class OrthogonalUpsample(torch.nn.Module): + """ + A very simple form of upsampling with an orthogonal matrix. + + proj_dim: the number of channels that will actually be projected; the rest are just copied. + proj_dim=channels would mean all channels are projected in a learned way + + """ + def __init__(self, channels: int, proj_dim: int): + super().__init__() + assert proj_dim <= channels + # gradually make smaller and then turn off the non-orthognality penalty. + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, + penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*2), batch_size, num_channels // 2) + """ + proj_channels = self.proj.weight.shape[0] + (seq_len, batch_size, in_channels) = src.shape + + if proj_channels < in_channels: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + + src = torch.stack((src[..., 0::2], src[..., 1::2]), + dim=1) # (seq_len, 2, batch_size, in_channels // 2) + src = src.reshape(seq_len * 2, batch_size, in_channels // 2) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the Fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0, embed_dim + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0, length_factor + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + dropout: dropout probability for attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, + bias=True, initial_scale=0.125 * query_head_dim**-0.25 + ) + + + self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.5)) + + + # the following are for diagnostics only, see --print-diagnostics option + self.copy_query = Identity() + self.copy_key = Identity() + + self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.6), (20000.0, 6.0), default=5.0)) + + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, head_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.copy_key(k) + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + if aux_loss_scale: + k = with_loss(k, + self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), + aux_loss_scale / num_heads, + key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None), + None) + + + # time1 refers to target, time2 refers to source. + q = q.permute(1, 2, 0, 3) # (batch, head, time1, query_head_dim) + k = k.permute(1, 2, 0, 3) # (batch, head, time2, query_head_dim) + + if self.training: + k = with_loss(k, + self.qk_max_product(q.reshape(batch_size * num_heads, seq_len, query_head_dim), + k.reshape(batch_size * num_heads, seq_len, query_head_dim), + aux_loss_scale / num_heads), + None) + + + attn_scores = RelativePositionAttentionFunction.apply(q.contiguous(), k.contiguous(), pos_emb.repeat(num_heads, 1, 1)) + + + assert attn_scores.shape == (batch_size, num_heads, seq_len, seq_len) + attn_scores = attn_scores.permute(1, 0, 2, 3) + # (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + + # HERE.. not finished streaming code. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, + bias=True, out_groups=num_heads) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + f = max(1.0, embed_dim / (num_heads * value_head_dim)) + + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) + + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + aux_loss_scale: float = 0.0, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + src_key_padding_mask: optional Tensor of shape (batch_size, src_seq_len); only + used for the cosine similarity loss, during training. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # x: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + if aux_loss_scale: + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), + aux_loss_scale, + mask=src_key_padding_mask), None) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + # try to get in the useful range of the activation function, i.e. not too small. + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed + # to the TransformedAdam optimizer. + self.in_proj.weight_min_rms = 0.02 + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwashL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.5, + ) + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) + + + def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + x = self.in_proj(x) + x = self.out_proj(x) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), None) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwashR", + dropout_p=0.0, + initial_scale=0.05, + ) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.6)) + + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), + None) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + seq_len = 20 + # Just make sure the forward pass runs. + + input_dim = 50 + + c = Zipformer2( + input_dim=input_dim, + encoder_dim=(64, 96), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + + batch_size = 6 + seq_len = 21 + # Just make sure the forward pass runs. + f, lengths = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + sd_prob=0.1, + ) + f.sum().backward() + c.eval() + x_ = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + sd_prob=0.1, + ) + x_ # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) From 4c919e06d000ee4baba9b2e09add94bbce46b88f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Sep 2025 02:27:01 +0800 Subject: [PATCH 0556/1191] Take some files from 1226 to allow parallel run. --- egs/librispeech/ASR/zapformer2/.gitignore | 1 + .../ASR/zapformer2/asr_datamodule.py | 454 ++++ .../ASR/zapformer2/attention_decoder.py | 1 + egs/librispeech/ASR/zapformer2/beam_search.py | 1 + egs/librispeech/ASR/zapformer2/ctc_decode.py | 1 + egs/librispeech/ASR/zapformer2/decode.py | 1089 +++++++++ .../ASR/zapformer2/decode_gigaspeech.py | 1 + .../ASR/zapformer2/decode_stream.py | 1 + egs/librispeech/ASR/zapformer2/decoder.py | 1 + .../ASR/zapformer2/encoder_interface.py | 1 + .../ASR/zapformer2/export-onnx-ctc.py | 1 + .../zapformer2/export-onnx-streaming-ctc.py | 1 + .../ASR/zapformer2/export-onnx-streaming.py | 1 + egs/librispeech/ASR/zapformer2/export-onnx.py | 1 + egs/librispeech/ASR/zapformer2/export.py | 1 + egs/librispeech/ASR/zapformer2/finetune.py | 1 + .../ASR/zapformer2/generate_averaged_model.py | 1 + .../ASR/zapformer2/jit_pretrained.py | 1 + .../ASR/zapformer2/jit_pretrained_ctc.py | 1 + .../zapformer2/jit_pretrained_streaming.py | 1 + egs/librispeech/ASR/zapformer2/joiner.py | 1 + .../ASR/zapformer2/label_smoothing.py | 1 + egs/librispeech/ASR/zapformer2/model.py | 630 +++++ egs/librispeech/ASR/zapformer2/my_profile.py | 1 + egs/librispeech/ASR/zapformer2/onnx_check.py | 1 + egs/librispeech/ASR/zapformer2/onnx_decode.py | 1 + .../onnx_pretrained-streaming-ctc.py | 1 + .../zapformer2/onnx_pretrained-streaming.py | 1 + .../ASR/zapformer2/onnx_pretrained.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_H.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_HLG.py | 1 + .../onnx_pretrained_ctc_HLG_streaming.py | 1 + egs/librispeech/ASR/zapformer2/optim.py | 1 + egs/librispeech/ASR/zapformer2/pretrained.py | 1 + .../ASR/zapformer2/pretrained_ctc.py | 1 + .../relative_position_attention_bwd_k_2.py | 321 +++ .../relative_position_attention_bwd_pos_2.py | 321 +++ .../relative_position_attention_bwd_q_2.py | 332 +++ .../relative_position_attention_fwd_2.py | 302 +++ ...ive_position_attention_module_optimized.py | 118 + egs/librispeech/ASR/zapformer2/scaling.py | 1 + .../ASR/zapformer2/scaling_converter.py | 1 + .../ASR/zapformer2/speech_recognition.py | 229 ++ .../ASR/zapformer2/streaming_beam_search.py | 1 + .../ASR/zapformer2/streaming_decode.py | 1 + egs/librispeech/ASR/zapformer2/subsampling.py | 1 + .../ASR/zapformer2/test_scaling.py | 1 + .../ASR/zapformer2/test_subsampling.py | 1 + egs/librispeech/ASR/zapformer2/train.py | 1678 +++++++++++++ egs/librispeech/ASR/zapformer2/zipformer.py | 2066 +++++++++++++++++ 52 files changed, 7581 insertions(+) create mode 100644 egs/librispeech/ASR/zapformer2/.gitignore create mode 100755 egs/librispeech/ASR/zapformer2/asr_datamodule.py create mode 120000 egs/librispeech/ASR/zapformer2/attention_decoder.py create mode 120000 egs/librispeech/ASR/zapformer2/beam_search.py create mode 120000 egs/librispeech/ASR/zapformer2/ctc_decode.py create mode 100755 egs/librispeech/ASR/zapformer2/decode.py create mode 120000 egs/librispeech/ASR/zapformer2/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/zapformer2/decode_stream.py create mode 120000 egs/librispeech/ASR/zapformer2/decoder.py create mode 120000 egs/librispeech/ASR/zapformer2/encoder_interface.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx.py create mode 120000 egs/librispeech/ASR/zapformer2/export.py create mode 120000 egs/librispeech/ASR/zapformer2/finetune.py create mode 120000 egs/librispeech/ASR/zapformer2/generate_averaged_model.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/joiner.py create mode 120000 egs/librispeech/ASR/zapformer2/label_smoothing.py create mode 100755 egs/librispeech/ASR/zapformer2/model.py create mode 120000 egs/librispeech/ASR/zapformer2/my_profile.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_check.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_decode.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/optim.py create mode 120000 egs/librispeech/ASR/zapformer2/pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/pretrained_ctc.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py create mode 120000 egs/librispeech/ASR/zapformer2/scaling.py create mode 120000 egs/librispeech/ASR/zapformer2/scaling_converter.py create mode 100755 egs/librispeech/ASR/zapformer2/speech_recognition.py create mode 120000 egs/librispeech/ASR/zapformer2/streaming_beam_search.py create mode 120000 egs/librispeech/ASR/zapformer2/streaming_decode.py create mode 120000 egs/librispeech/ASR/zapformer2/subsampling.py create mode 120000 egs/librispeech/ASR/zapformer2/test_scaling.py create mode 120000 egs/librispeech/ASR/zapformer2/test_subsampling.py create mode 100755 egs/librispeech/ASR/zapformer2/train.py create mode 100644 egs/librispeech/ASR/zapformer2/zipformer.py diff --git a/egs/librispeech/ASR/zapformer2/.gitignore b/egs/librispeech/ASR/zapformer2/.gitignore new file mode 100644 index 0000000000..e47ac15828 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/librispeech/ASR/zapformer2/asr_datamodule.py b/egs/librispeech/ASR/zapformer2/asr_datamodule.py new file mode 100755 index 0000000000..4db6e101fb --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/asr_datamodule.py @@ -0,0 +1,454 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +# This K2SpeechRecognitionDataset is a modified version of one from +# lhotse.dataset, modified to, in training mode, to return a batch that has 3 +# different copies of the same data with the last two having different Musan +# augmentations and the first having none; and also include the key "num_copies" +# in the batch which would be 1 for the validation data (no Musan) and 3 for the +# training data with musan. +from speech_recognition import K2SpeechRecognitionDataset +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=[], + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zapformer2/attention_decoder.py b/egs/librispeech/ASR/zapformer2/attention_decoder.py new file mode 120000 index 0000000000..830180a0cd --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/attention_decoder.py @@ -0,0 +1 @@ +../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/beam_search.py b/egs/librispeech/ASR/zapformer2/beam_search.py new file mode 120000 index 0000000000..8554e44ccf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/ctc_decode.py b/egs/librispeech/ASR/zapformer2/ctc_decode.py new file mode 120000 index 0000000000..a78e5c1df0 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/ctc_decode.py @@ -0,0 +1 @@ +../zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode.py b/egs/librispeech/ASR/zapformer2/decode.py new file mode 100755 index 0000000000..221f01297b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode.py @@ -0,0 +1,1089 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2] + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py new file mode 120000 index 0000000000..63b0ef617b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py @@ -0,0 +1 @@ +../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode_stream.py b/egs/librispeech/ASR/zapformer2/decode_stream.py new file mode 120000 index 0000000000..4e59d04a12 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode_stream.py @@ -0,0 +1 @@ +../zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decoder.py b/egs/librispeech/ASR/zapformer2/decoder.py new file mode 120000 index 0000000000..cab465d2b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decoder.py @@ -0,0 +1 @@ +../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/encoder_interface.py b/egs/librispeech/ASR/zapformer2/encoder_interface.py new file mode 120000 index 0000000000..aa5d0217a8 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py new file mode 120000 index 0000000000..dc14e93e75 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py new file mode 120000 index 0000000000..3baa2b673c --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py new file mode 120000 index 0000000000..d18cb9a9a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx.py b/egs/librispeech/ASR/zapformer2/export-onnx.py new file mode 120000 index 0000000000..f343cf7027 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx.py @@ -0,0 +1 @@ +../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export.py b/egs/librispeech/ASR/zapformer2/export.py new file mode 120000 index 0000000000..1a126ab695 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export.py @@ -0,0 +1 @@ +../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/finetune.py b/egs/librispeech/ASR/zapformer2/finetune.py new file mode 120000 index 0000000000..0e9e7989b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/finetune.py @@ -0,0 +1 @@ +../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/generate_averaged_model.py b/egs/librispeech/ASR/zapformer2/generate_averaged_model.py new file mode 120000 index 0000000000..b65513a058 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/generate_averaged_model.py @@ -0,0 +1 @@ +../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained.py b/egs/librispeech/ASR/zapformer2/jit_pretrained.py new file mode 120000 index 0000000000..5d45825206 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py new file mode 120000 index 0000000000..43aeb684bf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py new file mode 120000 index 0000000000..8e5e6f9812 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/joiner.py b/egs/librispeech/ASR/zapformer2/joiner.py new file mode 120000 index 0000000000..444cb5f150 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/joiner.py @@ -0,0 +1 @@ +../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/label_smoothing.py b/egs/librispeech/ASR/zapformer2/label_smoothing.py new file mode 120000 index 0000000000..3690afff9d --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/label_smoothing.py @@ -0,0 +1 @@ +../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/model.py b/egs/librispeech/ASR/zapformer2/model.py new file mode 100755 index 0000000000..278e498032 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/model.py @@ -0,0 +1,630 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from encoder_interface import EncoderInterface +from scaling import ScaledLinear, convert_num_channels, PredictLoss +from icefall.utils import add_sos, make_pad_mask, time_warp + + +class AsrModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + use_attention_decoder: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + use_attention_decoder: + Whether use attention-decoder head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.predict_loss = PredictLoss(encoder_dim) + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.1, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.1, + ) + + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + nn.LogSoftmax(dim=-1), + ) + + self.use_attention_decoder = use_attention_decoder + if use_attention_decoder: + self.attention_decoder = attention_decoder + else: + assert attention_decoder is None + + self.reconstruction_proj = ScaledLinear( + encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) + + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, sd_prob: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. + + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) + + x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + + src_key_padding_mask = make_pad_mask(x_lens) # (N, T) + specaug_mask = specaug_mask[:, ::2] + assert abs(specaug_mask.shape[1] - src_key_padding_mask.shape[1]) < 10 + specaug_mask = convert_num_channels(specaug_mask, src_key_padding_mask.shape[1]) # pad or truncate. (N, T) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, + aux_loss_scale=aux_loss_scale, + sd_prob=0.0) + + predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + + return encoder_out, encoder_out_lens, predict_loss + + + def compute_predict_loss(self, + encoder_out: Tensor, + src_key_padding_mask: Optional[Tensor], + specaug_mask: Optional[Tensor]) -> Tensor: + if src_key_padding_mask is not None and specaug_mask is not None: + mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) + elif src_key_padding_mask is not None: + mask = src_key_padding_mask.t().logical_not() + elif specaug_mask is not None: + mask = specaug_mask.t().logical_not() + else: + mask = None + return self.predict_loss(encoder_out, mask) + + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). This activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + return ctc_loss + + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss, with consistency regularization loss if we are in training mode. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.amp.autocast('cuda', enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.amp.autocast('cuda', enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + spec_augment: Optional[nn.Module] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + num_copies: int = 1, + aux_loss_scale: float = 0.0, + sd_prob: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + spec_augment: + The SpecAugment instance, or similar/compatible object, that masks + log-mel features. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. Used only for + time-warping, if num_copies > 1. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if num_copies > 1, corresponds to training mode. + num_copies: + the number of copies of the same data that are in the batch, e.g. 1, 2 + or 3; affects CRCTC, spec-augment, etc. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. + + Returns: + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + device = x.device + + if num_copies > 1: + assert num_copies == 3 # for now. + # will do SpecAugment or similar. + assert spec_augment is not None and getattr(spec_augment, 'time_warp_factor', -1) < 0 + + (batch_size, seq_len, num_channels) = x.shape + B = batch_size // num_copies + x = x.reshape(num_copies, B, seq_len, num_channels) + + do_time_warp = True + if do_time_warp: + # Apply time warping. First append the copies on the channel + # dimension so all copies get the exact same time-warping. + x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) + + assert supervision_segments is not None + with torch.amp.autocast('cuda', enabled=False): + x = time_warp( + x.to(torch.float), + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments[:B], + ) + x = x.reshape(B, seq_len, num_copies, num_channels) + x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + + # x_no_specaug is several repeats of the 1st copy of the data, which + # is the one not augmented with Musan. But it does have time + # warping and mel warping. + x_no_specaug = x[0:1].repeat(num_copies - 1, 1, 1, 1).reshape( + B * (num_copies - 1), seq_len, num_channels) + + + # Independently apply frequency masking and time masking to all but the first + # copy of the data. + x = spec_augment(x[1:].reshape(-1, seq_len, num_channels)) + + x_lens = x_lens[:B*(num_copies-1)] + y = y[:B*(num_copies-1)] + else: + x_no_specaug = x + + + # Compute encoder outputs + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale, + sd_prob=sd_prob) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + targets = y.values + if not self.training: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) + + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, + encoder_out_lens) + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss + + + def forward_reconstruction_loss(self, + log_mels: Tensor, + encoder_out: Tensor, + encoder_out_lens: Tensor): + """ + Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If + use_cr_ctc then we swap the first and second halves of the batch. + + Args: + log_mels: log-mel features of shape (batch_size, T, num_mels) + encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) + """ + batch_size = log_mels.shape[0] + num_mels = log_mels.shape[2] + + + def gauss_norm(x): + # normalize by gaussianizing on each dimension + values, indexes = x.sort(dim=1) # sort on seq dim + N = max(2, x.shape[1]) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[1], device=x.device, dtype=torch.float) + norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data + norm_rank = norm_rank.reshape(1, -1, 1) + norm_rank = norm_rank.repeat(x.shape[0], 1, x.shape[2]) + x_norm = torch.empty_like(x) + x_norm.scatter_(dim=1, index=indexes, src=norm_rank) + return x_norm + + log_mels = gauss_norm(log_mels) + + pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) + T_embed = pred_mels.shape[1] + pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) + + excess_frames = log_mels.shape[1] - pred_mels.shape[1] + assert 4 < excess_frames < 10 # should be around 7 or 8 I believe. + + T = pred_mels.shape[1] + offset = 3 # i found excess_frames = 5 one time. + log_mels = log_mels[:, offset:offset+T] + + lens = encoder_out_lens * 4 + pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions + assert pad_mask.shape == (batch_size, T) + pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position + # padd_mask: (batch_size, T, 1) + + + # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly + # helps to down-weight the effect of very silent silences. + #loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, + # reduction='none', beta=1.0) + # this way of applying the padding mask is not really ideal in terms of normalization, + # it will cause us to under-normalize a bit. + diff = log_mels * pad_mask - pred_mels * pad_mask + + loss = (diff ** 2) + + # removing the masking logic since we now use the no-specaug reference sequence. + ## masking. if it's different from the next item on both the frequency dim + ## and the time dim, it means we are in neither a time masked nor a frequency masked + ## position. + #mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), + # log_mels != torch.roll(log_mels, 1, dims=1)) + #loss = loss * mask.to(loss.dtype) + + loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. + return loss diff --git a/egs/librispeech/ASR/zapformer2/my_profile.py b/egs/librispeech/ASR/zapformer2/my_profile.py new file mode 120000 index 0000000000..76e48b756b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/my_profile.py @@ -0,0 +1 @@ +../zipformer/my_profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_check.py b/egs/librispeech/ASR/zapformer2/onnx_check.py new file mode 120000 index 0000000000..7293c70d46 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_check.py @@ -0,0 +1 @@ +../zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_decode.py b/egs/librispeech/ASR/zapformer2/onnx_decode.py new file mode 120000 index 0000000000..9e3faa5e01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_decode.py @@ -0,0 +1 @@ +../zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py new file mode 120000 index 0000000000..f8abb9daa5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py new file mode 120000 index 0000000000..11b846322e --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained.py new file mode 120000 index 0000000000..a085def837 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py new file mode 120000 index 0000000000..0c082a204f --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py new file mode 120000 index 0000000000..68102c7374 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py new file mode 120000 index 0000000000..8314b4efdf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py new file mode 120000 index 0000000000..7a637a1c01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py new file mode 120000 index 0000000000..a5b04b3f8b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/optim.py b/egs/librispeech/ASR/zapformer2/optim.py new file mode 120000 index 0000000000..207eecfcda --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained.py b/egs/librispeech/ASR/zapformer2/pretrained.py new file mode 120000 index 0000000000..70ad71ffc6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/pretrained.py @@ -0,0 +1 @@ +../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/pretrained_ctc.py new file mode 120000 index 0000000000..fb9bdf1fa2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py new file mode 100755 index 0000000000..aa85d1fff7 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_k_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in scores_grad + BLOCK_N: tl.constexpr, # block size in q + BLOCK_C: tl.constexpr, # block size for seq_q + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_k, BLOCK_M) + num_pid_n = tl.cdiv(channels, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(BLOCK_M == 1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), for k, seq_k + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,), for j, channel + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n_mask = offs_n[:, None] < channels + + # (BLOCK_C,), for i, seq_q + offs_c = tl.arange(0, BLOCK_C) + + q_base = q_ptr + batch * stride_qb + head * stride_qh + offs_n[:, None] * stride_qc + k_base = k_ptr + batch * stride_kb + head * stride_kh + pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc + scores_grad_base = ( + scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sk + ) + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + scores_grad_mask = (offs_m < seq_k) & (c_idx[None, :] < seq_q) + + # (BLOCK_N, BLOCK_C), or (J, I) + q_mask = offs_n_mask & (c_idx[None, :] < seq_q) + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + rel_idx = c_idx[None, :] - offs_m + max_seq_len - 1 + + # (BLOCK_M, BLOCK_N, BLOCK_C), or (K, J, I), or (BLOCK_N, BLOCK_C) + pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask + + scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sq + q_ptrs = q_base + c_idx[None, :] * stride_qs + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + # (BLOCK_N, BLOCK_C), or (J, I) + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C), or (J, I) + pos_ptrs = pos_base + rel_idx * stride_ps + + pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) + + # scores_grad_chunk (1, BLOCK_C), or (K, I) + # q_chunk (BLOCK_N, BLOCK_C), or (J, I) + # pos_chunk (BLOCK_N, BLOCK_C), or (J, I) + qp = q_chunk * pos_chunk + + acc += tl.sum(scores_grad_chunk * qp, axis=1) + + k_ptrs = k_base + offs_m * stride_ks + offs_n * stride_kc + k_mask = (offs_m < seq_k) & (offs_n < channels) + tl.store(k_ptrs, acc, mask=k_mask) + + +def relative_position_attention_bwd_k(scores_grad, q, pos): + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + assert q.is_contiguous() + assert pos.is_contiguous() + + assert scores_grad.ndim == q.ndim == 4, (scores_grad.shape, q.shape) + + assert pos.ndim == 3, pos.shape + b, h, seq_q, seq_k = scores_grad.shape + + assert q.shape[0] == b, q.shape + assert q.shape[1] == h, q.shape + assert q.shape[2] == seq_q, q.shape + + c = q.shape[3] + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert scores_grad.device == q.device == pos.device, ( + scores_grad.device, + q.device, + pos.device, + ) + + k = torch.empty(b, h, seq_k, c, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_k, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_k_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return k + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_k(scores_grad, q, pos), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + pos_copy = pos.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("k.grad", k.grad.shape, k.grad.sum()) + + scores_grad = scores0.grad.clone() + k_grad = relative_position_attention_bwd_k(scores_grad, q_copy, pos_copy) + + print(k_grad.shape, k_grad.sum()) + print((k.grad - k_grad).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py new file mode 100755 index 0000000000..93d1f09dc3 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 16, + "BLOCK_C": 16, + "GROUP_SIZE_M": 4, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_pos_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in q + BLOCK_N: tl.constexpr, # block size in k + BLOCK_C: tl.constexpr, # block size for channel + GROUP_SIZE_M: tl.constexpr, # size for grouped block, not used +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_n = tl.cdiv(seq_k, BLOCK_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(BLOCK_M == 1) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,) + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, 1) + rel_idx = offs_m - offs_n[:, None] + max_seq_len - 1 + + q_base = q_ptr + batch * stride_qb + head * stride_qh + k_base = k_ptr + batch * stride_kb + head * stride_kh + pos_base = pos_ptr + head * stride_ph + + scores_grad_base = scores_grad_ptr + batch * stride_sb + head * stride_sh + scores_grad_ptrs = ( + scores_grad_base + offs_m * stride_sq + offs_n[:, None] * stride_sk + ) + + # (BLOCK_N, 1) + scores_grad_mask = (offs_m < seq_q) & (offs_n[:, None] < seq_k) + + # (BLOCK_N, 1) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (1, BLOCK_C) + q_mask = (offs_m < seq_q) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C), or (K, J) + k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + pos_mask = ( + (rel_idx >= 0) + & (rel_idx < 2 * max_seq_len - 1) + & (c_idx[None, :] < channels) + ) + + q_ptrs = q_base + offs_m * stride_qs + c_idx[None, :] * stride_qc + k_ptrs = k_base + offs_n[:, None] * stride_ks + c_idx[None, :] * stride_kc + + # (1, BLOCK_C) + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs = pos_base + rel_idx * stride_ps + c_idx[None, :] * stride_pc + + # q_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # scores_grad_chunk (BLOCK_N, 1) + # + # pos_chunk: (BLOCK_N, BLOCK_C) + qk = q_chunk * k_chunk + pos_chunk = scores_grad_chunk * qk + + tl.atomic_add(pos_ptrs, pos_chunk, mask=pos_mask) + + +def relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len): + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + + assert q.is_contiguous() + assert k.is_contiguous() + + assert scores_grad.ndim == q.ndim == k.ndim == 4, ( + scores_grad.shape, + q.shape, + k.shape, + ) + b, h, seq_q, seq_k = scores_grad.shape + c = q.shape[3] + + assert k.shape[0] == b, k.shape + assert k.shape[1] == h, k.shape + assert k.shape[2] == seq_k, k.shape + assert k.shape[3] == c, k.shape + + assert q.shape[0] == b, q.shape + assert q.shape[1] == h, q.shape + assert q.shape[2] == seq_q, q.shape + + assert scores_grad.device == q.device == k.device, ( + scores_grad.device, + q.device, + k.device, + ) + + pos = torch.zeros(h, 2 * max_seq_len - 1, c, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_pos_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return pos + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + k_copy = k.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("pos.grad", pos.grad.shape, pos.grad.sum()) + + pos_grad = relative_position_attention_bwd_pos( + scores0.grad, q_copy, k_copy, max_seq_len + ) + + print(pos_grad.shape, pos_grad.sum()) + print((pos.grad - pos_grad).abs().max()) + + +def main(): + # test_benchmark() + test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py new file mode 100755 index 0000000000..5a9ececf0c --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_q_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in scores_grad + BLOCK_N: tl.constexpr, # block size in channels + BLOCK_C: tl.constexpr, # block size for seq_k + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_q, BLOCK_M) + num_pid_n = tl.cdiv(channels, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(BLOCK_M == 1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), we should always set BLOCK_M to 1 + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,) for channels + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,), for seq_k + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, 1) + offs_n_mask = offs_n[:, None] < channels + + q_base = q_ptr + batch * stride_qb + head * stride_qh + k_base = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_kc + pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc + scores_grad_base = ( + scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sq + ) + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, seq_k, BLOCK_C): + c_idx = c + offs_c + + # (1, BLOCK_C) + rel_idx = offs_m - c_idx[None, :] + max_seq_len - 1 + + # (1, BLOCK_C) + scores_grad_mask = (offs_m < seq_q) & (c_idx[None, :] < seq_k) + + # (BLOCK_N, BLOCK_C) + k_mask = offs_n_mask & (c_idx[None, :] < seq_k) + + # (BLOCK_N, BLOCK_C) + pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask + + scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sk + k_ptrs = k_base + c_idx[None, :] * stride_ks + + # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs = pos_base + rel_idx * stride_ps + + pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) + + # scores_grad_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # pos_chunk (BLOCK_N, BLOCK_C) + + # kp: (BLOCK_N, BLOCK_C) + kp = k_chunk * pos_chunk + + acc += tl.sum(scores_grad_chunk * kp, axis=1) + + q_ptrs = q_base + offs_m * stride_qs + offs_n * stride_qc + q_mask = (offs_m < seq_q) & (offs_n < channels) + tl.store(q_ptrs, acc, mask=q_mask) + + +def relative_position_attention_bwd_q(scores_grad, k, pos): + """ + Args: + scores_grad: (b, h, seq_q, seq_k) + k: (b, h, seq_k, channels) + pos: (h, 2*max_seq_len-1, channels) + Returns: + grad of q: (b, h, seq_q, channels) + """ + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + assert k.is_contiguous() + assert pos.is_contiguous() + + assert scores_grad.ndim == k.ndim == 4, (scores_grad.shape, k.shape) + assert pos.ndim == 3, pos.shape + b, h, seq_q, seq_k = scores_grad.shape + + c = k.shape[3] + + assert k.shape[0] == b, (k.shape, scores_grad.shape) + assert k.shape[1] == h, (k.shape, scores_grad.shape) + assert k.shape[2] == seq_k, (k.shape, scores_grad.shape) + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert scores_grad.device == k.device == pos.device, ( + scores_grad.device, + k.device, + pos.device, + ) + + q = torch.empty(b, h, seq_q, c, device=k.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_q_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return q + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + k = torch.randn(b, h, seq_k, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_q(scores_grad, k, pos), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + k_copy = k.clone() + pos_copy = pos.clone() + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("q.grad", q.grad.shape, q.grad.sum()) + + scores_grad = scores0.grad.clone() + q_grad = relative_position_attention_bwd_q(scores_grad, k_copy, pos_copy) + print(q_grad.shape, q_grad.sum()) + print((q.grad - q_grad).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py new file mode 100755 index 0000000000..e6ea552035 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_fwd_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in q + BLOCK_N: tl.constexpr, # block size in k + BLOCK_C: tl.constexpr, # block size for channel + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_q, BLOCK_M) + num_pid_n = tl.cdiv(seq_k, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), we should always set BLOCK_M to 1 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # (BLOCK_N,) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,) + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, ) + rel_idx = offs_m - offs_n + max_seq_len - 1 + + # (BLOCK_N, 1) + rel_idx_mask = (rel_idx[:, None] >= 0) & (rel_idx[:, None] < 2 * max_seq_len - 1) + + q_ptrs = q_ptr + batch * stride_qb + head * stride_qh + offs_m[:, None] * stride_qs + k_ptrs = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_ks + + pos_ptrs = pos_ptr + head * stride_ph + rel_idx[:, None] * stride_ps + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (BLOCK_M, BLOCK_C) + q_mask = (offs_m[:, None] < seq_q) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + pos_mask = rel_idx_mask & (c_idx[None, :] < channels) + + q_ptrs0 = q_ptrs + c_idx[None, :] * stride_qc + k_ptrs0 = k_ptrs + c_idx[None, :] * stride_kc + + # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) + q_chunk = tl.load(q_ptrs0, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs0, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs0 = pos_ptrs + c_idx[None, :] * stride_pc + + pos_chunk = tl.load(pos_ptrs0, mask=pos_mask, other=0.0) + + # q_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # pos_chunk (BLOCK_N, BLOCK_C) + + acc += tl.sum(q_chunk * (k_chunk * pos_chunk), axis=1) + + scores_ptrs = ( + scores_ptr + + batch * stride_sb + + head * stride_sh + + offs_m * stride_sq + + offs_n * stride_sk + ) + scores_mask = (offs_m < seq_q) & (offs_n < seq_k) + + tl.store(scores_ptrs, acc, mask=scores_mask) + + +def relative_position_attention_fwd(q, k, pos): + assert q.is_contiguous() + assert k.is_contiguous() + assert pos.is_contiguous() + + assert q.ndim == k.ndim == 4, (q.shape, k.shape) + assert pos.ndim == 3, pos.shape + b, h, seq_q, c = q.shape + assert k.shape[0] == b, k.shape + assert k.shape[1] == h, k.shape + assert k.shape[3] == c, k.shape + + seq_k = k.shape[2] + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert q.device == k.device == pos.device, ( + q.device, + k.device, + pos.device, + ) + + scores = torch.empty(b, h, seq_q, seq_k, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_fwd_kernel[grid]( + q, k, pos, scores, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3), + ) + # fmt: on + return scores + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton", "torch"], + line_names=["Triton", "Torch"], + styles=[("green", "-"), ("blue", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance with pos", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_fwd_torch(q, k, pos), + quantiles=quantiles, + ) + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_fwd(q, k, pos), quantiles=quantiles + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 8 + seq_q = 400 + seq_k = 400 + c = 1024 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores1 = relative_position_attention_fwd(q, k, pos) + print(scores0.shape, scores0.sum()) + print(scores1.shape, scores1.sum()) + print((scores0 - scores1).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py new file mode 100755 index 0000000000..21640764ba --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +import torch + +from relative_position_attention_fwd_2 import ( + relative_position_attention_fwd, + relative_position_attention_fwd_torch, +) + +from relative_position_attention_bwd_q_2 import relative_position_attention_bwd_q +from relative_position_attention_bwd_k_2 import relative_position_attention_bwd_k +from relative_position_attention_bwd_pos_2 import relative_position_attention_bwd_pos + + +class RelativePositionAttentionFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, pos): + """ + Args: + q: (batch, head, seq_q, channel) + k: (batch, head, seq_k, channel) + pos: (head, 2*max_seq_len-1, channel) + Returns: + scores: (batch, head, seq_q, seq_k) + """ + ctx.save_for_backward(q, k, pos) + return relative_position_attention_fwd(q, k, pos) + + @staticmethod + def backward(ctx, scores_grad): + q, k, pos = ctx.saved_tensors + q_grad = None + k_grad = None + pos_grad = None + + if ctx.needs_input_grad[0]: + q_grad = relative_position_attention_bwd_q(scores_grad, k, pos) + + if ctx.needs_input_grad[1]: + k_grad = relative_position_attention_bwd_k(scores_grad, q, pos) + + if ctx.needs_input_grad[2]: + max_seq_len = (pos.shape[1] + 1) // 2 + pos_grad = relative_position_attention_bwd_pos( + scores_grad, q, k, max_seq_len + ) + + return q_grad, k_grad, pos_grad + + +class RelativePositionAttentionModule(torch.nn.Module): + def forward( + self, q: torch.Tensor, k: torch.Tensor, pos: torch.Tensor + ) -> torch.Tensor: + """ + Args: + q: (batch, head, seq_q, channel) + k: (batch, head, seq_k, channel) + pos: (head, 2*max_seq_len-1, channel) + Returns: + scores: (batch, head, seq_q, seq_k) + """ + return RelativePositionAttentionFunction.apply(q, k, pos) + + +def _test(): + torch.manual_seed(20250820) + device = torch.device("cuda", 0) + b = 4 + h = 2 + seq_q = 100 + seq_k = 100 + c = 300 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + k_copy = k.clone() + pos_copy = pos.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + + q_copy.requires_grad_(True) + k_copy.requires_grad_(True) + pos_copy.requires_grad_(True) + + scores1 = RelativePositionAttentionModule()(q_copy, k_copy, pos_copy) + + s1 = (scale * scores1).sum() + s1.backward() + + print((s0 - s1).max().abs()) + print((q.grad - q_copy.grad).max().abs()) + print((k.grad - k_copy.grad).max().abs()) + print((pos.grad - pos_copy.grad).max().abs()) + """ + tensor(0.0005, device='cuda:0', grad_fn=) + tensor(7.6294e-06, device='cuda:0') + tensor(5.7220e-06, device='cuda:0') + tensor(3.4332e-05, device='cuda:0') + """ + + +if __name__ == "__main__": + _test() + pass diff --git a/egs/librispeech/ASR/zapformer2/scaling.py b/egs/librispeech/ASR/zapformer2/scaling.py new file mode 120000 index 0000000000..58e4b0a0fe --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/scaling.py @@ -0,0 +1 @@ +../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/scaling_converter.py b/egs/librispeech/ASR/zapformer2/scaling_converter.py new file mode 120000 index 0000000000..bc7c7b5e37 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/scaling_converter.py @@ -0,0 +1 @@ +../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/speech_recognition.py b/egs/librispeech/ASR/zapformer2/speech_recognition.py new file mode 100755 index 0000000000..dd069cf3da --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/speech_recognition.py @@ -0,0 +1,229 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class K2SpeechRecognitionDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech recognition task using k2 library. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + if self.cut_transforms: + orig_cuts = cuts + + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts + num_copies = 3 + else: + num_copies = 1 + + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) diff --git a/egs/librispeech/ASR/zapformer2/streaming_beam_search.py b/egs/librispeech/ASR/zapformer2/streaming_beam_search.py new file mode 120000 index 0000000000..97e6e733f2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/streaming_beam_search.py @@ -0,0 +1 @@ +../zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/streaming_decode.py b/egs/librispeech/ASR/zapformer2/streaming_decode.py new file mode 120000 index 0000000000..e31da07d01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/streaming_decode.py @@ -0,0 +1 @@ +../zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/subsampling.py b/egs/librispeech/ASR/zapformer2/subsampling.py new file mode 120000 index 0000000000..d178adc2e5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/subsampling.py @@ -0,0 +1 @@ +../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_scaling.py b/egs/librispeech/ASR/zapformer2/test_scaling.py new file mode 120000 index 0000000000..b776da79a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/test_scaling.py @@ -0,0 +1 @@ +../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_subsampling.py b/egs/librispeech/ASR/zapformer2/test_subsampling.py new file mode 120000 index 0000000000..2925ea3c51 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/test_subsampling.py @@ -0,0 +1 @@ +../zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/train.py b/egs/librispeech/ASR/zapformer2/train.py new file mode 100755 index 0000000000..4294e139f6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/train.py @@ -0,0 +1,1678 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Sched3, TransformedAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def get_adjusted_lr_batches(params: AttributeDict) -> float: + # returns an adjusted form of the "lr_batches" parameter used to set the learning + # rate in the Sched3 scheduler. + # We want the final LR to be based on the geometric mean of "how much data we + # have seen" and "how many batches we have seen". + # an easier way to look at it is this: the formula for learning rate depends + # on (cur_batch / lr_batches). if we write this as: + # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches + # then the numerator is a geometric mean of "how many batches we have seen" + # and "how much data we have seen". We can achieve this by setting + # lr_batches = params.lr_batches * (duration_ratio ** -0.5). + duration_ratio = (params.max_duration * params.world_size) / params.ref_duration + lr_batches = params.lr_batches * (duration_ratio ** -0.5) + logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") + return lr_batches + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def lookup(params: AttributeDict, name: str): + """ + Interprets numerical arguments in `params` by taking into account base-dim; + also parses comma-separated lists of integers, turning them into tuples. + If a particular attribute ending in "dim" is not present we look up + the same name but ending in "factor", and multiply the elements by base_dim. + """ + try: + attr = getattr(params, name) + try: + attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints + if len(attr) == 1: + attr = attr[0] + except: + pass # leave attr as it is, e.g. a string. + return attr + except AttributeError as e: + if name[-3:] != "dim": + raise e + try: + attr = getattr(params, name[:-3] + "multiple") + if isinstance(attr, str): + attr = tuple(map(int, attr.split(","))) # tuple of ints + base_dim = params.base_dim + attr = tuple([i * base_dim for i in attr]) + if len(attr) == 1: + attr = attr[0] + else: # assume int. + assert isinstance(attr, (int, float)), (name, attr) + attr = attr * params.base_dim + return attr + except AttributeError as e: + raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") + + + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="3,5,6,6,6,5", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--base-dim", + type=int, + default=64, + help="Dimension that, via multiples, defines the dimensions of the model." + ) + + parser.add_argument( + "--embed-multiple", + type=int, + default=6, + help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", + ) + + parser.add_argument( + "--feedforward-multiple", + type=str, + default="3,3,3,3,3,3", + help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-multiple", + type=str, + default="4,6,9,12,9,6", + help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-multiple", + type=int, + default=8, + help="Factor by which embedding dimension in the decoder model is larger than base-dim.", + ) + + parser.add_argument( + "--joiner-multiple", + type=int, + default=8, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--attention-decoder-multiple", + type=int, + default=8, + help="""Factor by which attention decoder dim is larger than base-dim""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-multiple", + type=int, + default=8, + help="""Determines attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-multiple", + type=int, + default=4, + help="""Factor by which feedforward hidden dim in attention decoder is larger than attention-decoder-dim""" + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=True, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--debug-interval", + type=int, + default=10, + help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for + finding parts of the network that are diverging or not well trained. + """ + ) + + parser.add_argument( + "--dump-debug-interval", + type=int, + default=0, + help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they + are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. + Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file + opened, seeked-in and closed for each scalar that is written). + """ + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=17500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--aux-loss-scale", + type=float, + default=0.05, + help="Scale on auxiliary losses that are defined in the model, such " + "as cosine loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--reconstruction-loss-scale", + type=float, + default=0.005, + help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", + ) + + parser.add_argument( + "--predict-loss-scale", + type=float, + default=0.01, + help="Prediction of random k-means after widest zipformer layer" + ) + + parser.add_argument( + "--stochastic-depth-prob", + type=float, + default=0.1, + help="Probability of using a randomly chosen stack output during training, instead of " + "final output." + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on pruned loss (for transducer) and the reconstruction and prediction + losses. Expressed in terms of the "adjusted batch count", i.e. the + normalized batch count after adjusting for changes in batch size. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=lookup(params, "embed_dim"), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + input_dim=lookup(params, "embed_dim"), + output_downsampling_factor=2, + downsampling_factor=lookup(params, "downsampling_factor"), + num_encoder_layers=lookup(params, "num_encoder_layers"), + encoder_dim=lookup(params, "encoder_dim"), + query_head_dim=lookup(params, "query_head_dim"), + value_head_dim=lookup(params, "value_head_dim"), + num_heads=lookup(params, "num_heads"), + feedforward_multiple=lookup(params, "feedforward_multiple"), + cnn_module_kernel=lookup(params, "cnn_module_kernel"), + dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero + causal=params.causal, + chunk_size=lookup(params, "chunk_size"), + left_context_frames=lookup(params, "left_context_frames"), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "decoder_dim"), + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 + joiner = Joiner( + encoder_dim=lookup(params, "embed_dim") * output_downsampling_factor, + decoder_dim=lookup(params, "decoder_dim"), + joiner_dim=lookup(params, "joiner_dim"), + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "attention_decoder_dim"), + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=lookup(params, "attention_decoder_attention_dim"), + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_multiple * lookup(params, "attention_decoder_attention_dim"), + memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + output_downsampling_factor = 2 + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=output_downsampling_factor * lookup(params, "embed_dim"), + decoder_dim=lookup(params, "decoder_dim"), + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + spec_augment: Optional[nn.Module] = None, + aux_loss_scale: float = 0.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + spec_augment: + The nn.Module instance (or similar object), used for training + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + num_copies = batch["num_copies"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + if num_copies > 1: + assert model.training + # will need the following for time-warping in nn.Module. + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + spec_augment = None # disable spec-aug + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=80, # for specaug + num_copies=num_copies, + aux_loss_scale=aux_loss_scale, + sd_prob=(params.stochastic_depth_prob if is_training else 0.0), + ) + + loss = 0.0 + + adjusted_batch_count = params.batch_idx_train + warm_step = params.warm_step + def warmup_schedule(scale, initial_factor): + # geometric warmup schedules. + warmup_factor = (1. if adjusted_batch_count >= warm_step else + initial_factor + (adjusted_batch_count / warm_step) * (1 - initial_factor)) + return scale * warmup_factor + + if params.use_transducer: + simple_loss_scale = params.simple_loss_scale + pruned_loss_scale = warmup_schedule(1.0, 0.05) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if num_copies > 1: + loss += params.cr_loss_scale * cr_loss + + reconstruction_loss_scale = params.reconstruction_loss_scale + + loss += reconstruction_loss_scale * reconstruction_loss + + if num_copies > 1: + loss += params.predict_loss_scale * predict_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + nframes = (feature_lens // params.subsampling_factor).sum().item() + if num_copies > 1: + nframes = nframes * (num_copies - 1) / num_copies # omit 1st copy + info["frames"] = nframes + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if num_copies > 1: + info["cr_loss"] = cr_loss.detach().cpu().item() + if num_copies > 1: + info["predict_loss"] = predict_loss.detach().cpu().item() + info["recon_loss"] = reconstruction_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + spec_augment: Optional[nn.Module] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + spec_augment: + The SpecAugment or similar instance used for CR-CTC. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def get_scaler_scale(): + if params.use_autocast and scaler._scale is not None: + return scaler._scale.item() + else: + return 1.0 + + def save_bad_model(suffix: str = ""): + if params.debug_interval > 0: + optimizer.write_debug_info(summary_writer=tb_writer) + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = get_scaler_scale() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or + batch_idx % 100 == 0 and cur_grad_scale < 8.0 or + batch_idx % 400 == 0 and cur_grad_scale < 32.0): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = get_scaler_scale() + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + optimizer.write_debug_info(summary_writer=tb_writer) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + + assert params.use_ctc # for now, require CTC, we may remove this requirement later. + + spec_augment = ExpAugment() + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = TransformedAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + debug_interval=params.debug_interval, + ) + + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and False: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + spec_augment=spec_augment, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + spec_augment=spec_augment, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + d = diagnostic.print_diagnostics() + filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" + torch.save(d, filename) + logging.info(f"Saved detailed diagnostics to {filename}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + spec_augment: Optional[nn.Module] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer2/zipformer.py b/egs/librispeech/ASR/zapformer2/zipformer.py new file mode 100644 index 0000000000..f5e1afe779 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/zipformer.py @@ -0,0 +1,2066 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union +from relative_position_attention_module_optimized import RelativePositionAttentionFunction +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + OrthogonalLinear, + SimpleOrthogonalLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + ScaleLimiter, + ActivationDropoutAndLinear, + ExpNorm, + ChunkCausalDepthwiseConv1d, + CosineSimilarityLoss, + MinProductLoss, + MaxProductLoss, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, + with_loss, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + + dropout (float): dropout rate + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + def __init__( + self, + input_dim: int, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + query_head_dim: Union[int, Tuple[int]] = 24, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_multiple: Union[int, Tuple[int]] = 4, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + dropout: FloatLike = None, # see code below for default + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_multiple = _to_tuple(feedforward_multiple) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + # each one will be Zipformer2Encoder or OrthogonalDownsample or OrthogonalUpsample + encoders = [] + + num_encoders = len(downsampling_factor) + + # caution: some changes we made for this break the streaming, later we'll try to fix this. + encoders_downsampling_factors = [ ] + + # make it so large the limit is never reached. + max_proj_dim = max(downsampling_factor) * max(encoder_dim) + + + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_multiple=feedforward_multiple[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + num_conv_modules=(2 if downsampling_factor[i] == 1 else 1), + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + head_dim=query_head_dim[i], + dim=downsampling_factor[i]*input_dim, + out_proj=False, # (downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + sd_prob: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + sd_prob: + Stochastic-depth prob: with this probability we replace the final output + with the output of a randomly chosen stack (including the 'zero stack' which + means the original input x). Each stack except the 'zero stack' has a + separate output projection for stochastic depth, that only sees the + "non-bypass part", i.e. its encoder stack without the residual. + Returns: + Return (embeddings_lengths), where: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + chunk_size, left_context_chunks = self.get_chunk_info() + orig_seq_len = x.shape[0] + + pad = (-orig_seq_len) % max(self.downsampling_factor) + # pad sequence length to be multiple of max(self.downsampling_factor) + x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), + dim=0) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) + + num_stacks = len(self.downsampling_factor) + + x_sd = x + + def randomly_choose_seqs(x, this_x, prob: float): + batch_size = x.shape[1] + do_replace = (torch.rand(1, batch_size, 1, device=x.device) < prob).expand_as(x) + return torch.where(do_replace, this_x, x) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = downsample_by(x, ds) + T = x.shape[0] + x, this_x_sd = module( + x, + chunk_size=chunk_size, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=(None + if attn_mask is None + else attn_mask[::ds, ::ds] + ), + aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) + ) + x = upsample_by(x, ds) + if sd_prob: + x_sd = randomly_choose_seqs(x_sd, upsample_by(this_x_sd, ds), 1. / (2. + i)) + + + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + od = self.output_downsampling_factor + x = downsample_by(x, od) + x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + if sd_prob: + x_sd = downsample_by(x_sd, od) + x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + x = randomly_choose_seqs(x, x_sd, sd_prob) + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + new_states += new_layer_states + + x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. + + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def get_max_similarity(rank: int, power: float): + """ + This returns a value for the "max_similarity" argument of CosineSimilarityLoss. + the max_similarity is an upper limit we impose on the mean value of (x_i . x_j) + if i != j are two different sequence-position indexes and x_i and x_j are + activation vectors normalized to have unit length. + + rank: the dimension of the space, usually this is the num_channels, but if + we have just up-projected from a bottleneck, it would be the bottleneck + dimension. + power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean + we enforce the vector dimensions to be completely independent like Gaussian noise + (don't do this); if we set power=0.0 it would be equivalent to not having + the CosineSimilarityLoss at all. + + The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal + variable. If x consists of independent Gaussian noise of dimension D, with + variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" + would be close to a no-op for large D), then (x_i . x_j) would be distributed as + a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) + would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value + between this and 1, as a kind of heuristic limit on this max_similarity. + """ + return (0.7978845608 / (rank ** 0.5)) ** power + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +def pad_mask(mask: Optional[Tensor], seq_len: int): + # mask: (batch_size, old_seq_len) + # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). + if mask is None: + return None + (batch_size, old_seq_len) = mask.shape + pad = seq_len - old_seq_len + if pad == 0: + return mask + else: + return torch.cat((mask, torch.ones(batch_size, pad, device=mask.device, dtype=torch.bool)), + dim=1) + + +def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len // downsampling_factor, batch_size, num_channels * downsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len // downsampling_factor, downsampling_factor, batch_size, num_channels) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len // downsampling_factor, batch_size, downsampling_factor * num_channels) + return x + +def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len, batch_size, upsampling_factor, num_channels // upsampling_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + return x + + +def get_dct_matrix(N): + """ + Generates an orthonormal DCT-II matrix for a given size N. + Args: + N (int): The size of the square matrix. + Returns: + torch.Tensor: The N x N orthonormal DCT-II matrix. + """ + # Create the base matrix with dimensions (N, N) + mat = torch.zeros(N, N) + # Create a tensor for the indices k (rows) and n (columns) + k = torch.arange(N).unsqueeze(1) + n = torch.arange(N).unsqueeze(0) + # Fill the matrix using the DCT-II formula + mat = math.sqrt(2 / N) * torch.cos(math.pi / (2 * N) * (2 * n + 1) * k) + # Adjust the first row (k=0) with a special normalization factor + mat[0] *= (2 ** -0.5) + return mat + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_multiple: determines the hidden dimension of the feedforward module + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + value_head_dim: int, + feedforward_multiple: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + num_conv_modules: int = 2, + causal: bool = False, + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + self.name = None # will be set from training loop + + self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + num_heads=2 * num_heads, + query_head_dim=query_head_dim, + dropout=0.0, + ) + + self.self_attn1, self.self_attn2, self.self_attn3 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(3) ] + + feedforward_dim = embed_dim * feedforward_multiple + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) + + if num_conv_modules >= 2: + self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + if num_conv_modules >= 1: + self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + + self.scale_limiter = ScaleLimiter(max_var=2.0) + + self.norm = ExpNorm(embed_dim) + + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, head_dim) or (batch_size, 2*seq_len-1, head_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + aux_loss_scale=0.1 * aux_loss_scale, + ) + num_heads = attn_weights.shape[0] // 2 # num heads per self_attn module + attn_weights1 = attn_weights[:num_heads] + attn_weights2 = attn_weights[num_heads//2:-num_heads//2] + attn_weights3 = attn_weights[num_heads:] + + src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + if hasattr(self, 'conv_module1'): + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.self_attn3(src, attn_weights3, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + if hasattr(self, 'conv_module2'): + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) + offset = (src - src_orig) * residual_scale + src = src_orig + offset + + src = with_loss(src, + self.cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), + None) + + src = self.scale_limiter(src) + + src = self.norm(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.residual(src_orig, src) + + src = self.norm(src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + dim: the dimension of the input and output (layer dim may be less than this). + pos_dim: the dimension for the relative positional encoding +dropout: + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + + + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dim: int, + head_dim: int, + out_proj: bool, + ) -> None: + super().__init__() + + # self.downsample will also reverse the downsampling operation for us afterward. + self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) + self.proj.lr_scale = 0.75 + + self.encoder_pos = CompactRelPositionalEncoding( + head_dim, dropout_rate=0.0, length_factor=1.0 + ) + self.name = None + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.residual_scales = nn.Parameter( + torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), + (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], + dim=0)) + + self.copy_bypass = Identity() + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) + + # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear + # module. + if out_proj: + self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) + self.out_proj.lr_scale = 0.75 + + # stochastic-depth proj. + self.sd_proj = nn.Linear(encoder_layer.embed_dim, dim) + + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), + but embed_dim is allowed to exceed the modules' embed_dim; we will bypass + any extra dimensions. + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + (out, out_sd), both of the same shape as src, + where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. + """ + pos_emb = self.encoder_pos(src) + + src_orig_fulldim = src + + src = self.proj(src) # project to layer dim. + + num_layers = len(self.layers) + src_orig = src + + residual_scale = limit_param_value(self.residual_scales[0], + min=-1.0, max=-0.5) + src_with_bypass = residual_scale * src + + for i, mod in enumerate(self.layers): + src = mod( + src, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale/num_layers, + ) + residual_scale = limit_param_value(self.residual_scales[i + 1], + min=0.0 if i + 1 < num_layers else 0.1, + max=1.0) + src_with_bypass = src_with_bypass + residual_scale * src + + + offset = src_with_bypass + + src = src_orig_fulldim + self.proj(offset, transpose=True) + # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, + # because of some identities involving orthogonal matrices. + + if aux_loss_scale: + src = with_loss(src, + self.offset_cosine_loss(offset.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask) + + self.cosine_loss(src.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask), + None) + + src_sd = self.sd_proj(offset) + + if hasattr(self, 'out_proj'): + src = self.out_proj(src) + + return src, src_sd + + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + num_channels = src.shape[-1] + layer_dim = self.layers[0].embed_dim + if num_channels > layer_dim: + src, bypass = src[..., :layer_dim], src[..., layer_dim:] + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + src, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + src, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + if num_channels > layer_dim: + src = torch.cat((src, bypass), dim=-1) + + return src, new_states + + +class ResidualModule(nn.Module): + """ + An nn.Module that implements a learnable residual scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + function_scale_min: FloatLike = 0.1, + ): + super().__init__() + self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.function_scale_min = copy.deepcopy(function_scale_min) + + + def _get_scales(self): + function_scale = self.function_scale + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + function_scale = limit_param_value( + function_scale, min=float(self.function_scale_min), max=1.0, + ) + residual_scale = 1.0 - function_scale + return residual_scale, function_scale + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + residual_scale, function_scale = self._get_scales() + return residual_scale * src_orig + function_scale * src + + +class OrthogonalDownsample(torch.nn.Module): + """ + Downsamples on sequence axis by appending sequence-positions together, + and then optionally projects by an orthogonal matrix + + + +. Projection is initialized + in a special way and enforced to be orthogonal. + + Args: + channels: the number of input channels; the num output channels will be twice this + proj_dim: the number of channels, after combining 2 frames by interpolating their channels + as [ a b a b, .. ] that will actually be projected; the rest are just copied. + proj_dim=2 * channels would mean all channels are projected in a learned way + causal: True for causal systems, only affects error messages as requires even + input num frames. + """ + def __init__( + self, channels: int, proj_dim: int, causal: bool = False, + ): + super().__init__() + assert proj_dim <= channels * 2 + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + self.causal = causal + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + + if seq_len % 2 == 1: + if torch.jit.is_tracing(): + assert ( + not self.causal + ), f"pad should be zero for exporting streaming models. Given {pad}" + src = torch.cat((src, src[-1:]), dim=0) + seq_len += 1 + + # the following will place each 2 frames of a particular channel right after + # each other as if they were two different channels. + src = torch.stack((src[0::2], src[1::2]), dim=-1) + src = src.reshape(seq_len // 2, batch_size, in_channels * 2) + proj_channels = self.proj.weight.shape[0] + if proj_channels < in_channels * 2: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + return src + +class OrthogonalUpsample(torch.nn.Module): + """ + A very simple form of upsampling with an orthogonal matrix. + + proj_dim: the number of channels that will actually be projected; the rest are just copied. + proj_dim=channels would mean all channels are projected in a learned way + + """ + def __init__(self, channels: int, proj_dim: int): + super().__init__() + assert proj_dim <= channels + # gradually make smaller and then turn off the non-orthognality penalty. + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, + penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*2), batch_size, num_channels // 2) + """ + proj_channels = self.proj.weight.shape[0] + (seq_len, batch_size, in_channels) = src.shape + + if proj_channels < in_channels: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + + src = torch.stack((src[..., 0::2], src[..., 1::2]), + dim=1) # (seq_len, 2, batch_size, in_channels // 2) + src = src.reshape(seq_len * 2, batch_size, in_channels // 2) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the Fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0, embed_dim + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0, length_factor + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + dropout: dropout probability for attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, + bias=True, initial_scale=0.125 * query_head_dim**-0.25 + ) + + + self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.5)) + + + # the following are for diagnostics only, see --print-diagnostics option + self.copy_query = Identity() + self.copy_key = Identity() + + self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.6), (20000.0, 6.0), default=5.0)) + + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, head_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.copy_key(k) + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + if aux_loss_scale: + k = with_loss(k, + self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), + aux_loss_scale / num_heads, + key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None), + None) + + + # time1 refers to target, time2 refers to source. + q = q.permute(1, 2, 0, 3) # (batch, head, time1, query_head_dim) + k = k.permute(1, 2, 0, 3) # (batch, head, time2, query_head_dim) + + if self.training: + k = with_loss(k, + self.qk_max_product(q.reshape(batch_size * num_heads, seq_len, query_head_dim), + k.reshape(batch_size * num_heads, seq_len, query_head_dim), + aux_loss_scale / num_heads), + None) + + + attn_scores = RelativePositionAttentionFunction.apply(q.contiguous(), k.contiguous(), pos_emb.repeat(num_heads, 1, 1)) + + + assert attn_scores.shape == (batch_size, num_heads, seq_len, seq_len) + attn_scores = attn_scores.permute(1, 0, 2, 3) + # (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + + # HERE.. not finished streaming code. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, + bias=True, out_groups=num_heads) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + f = max(1.0, embed_dim / (num_heads * value_head_dim)) + + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) + + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + aux_loss_scale: float = 0.0, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + src_key_padding_mask: optional Tensor of shape (batch_size, src_seq_len); only + used for the cosine similarity loss, during training. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # x: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + if aux_loss_scale: + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), + aux_loss_scale, + mask=src_key_padding_mask), None) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + # try to get in the useful range of the activation function, i.e. not too small. + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed + # to the TransformedAdam optimizer. + self.in_proj.weight_min_rms = 0.02 + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwashL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.5, + ) + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) + + + def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + x = self.in_proj(x) + x = self.out_proj(x) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), None) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwashR", + dropout_p=0.0, + initial_scale=0.05, + ) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.6)) + + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), + None) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + seq_len = 20 + # Just make sure the forward pass runs. + + input_dim = 50 + + c = Zipformer2( + input_dim=input_dim, + encoder_dim=(64, 96), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + + batch_size = 6 + seq_len = 21 + # Just make sure the forward pass runs. + f, lengths = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + sd_prob=0.1, + ) + f.sum().backward() + c.eval() + x_ = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + sd_prob=0.1, + ) + x_ # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) From 8cc4256e2f6fce6735c1bf20701d4f3aa4fb129f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Sep 2025 02:27:11 +0800 Subject: [PATCH 0557/1191] Take some files from 1226 to allow parallel run. --- egs/librispeech/ASR/zapformer2/.gitignore | 1 + .../ASR/zapformer2/asr_datamodule.py | 454 ++++ .../ASR/zapformer2/attention_decoder.py | 1 + egs/librispeech/ASR/zapformer2/beam_search.py | 1 + egs/librispeech/ASR/zapformer2/ctc_decode.py | 1 + egs/librispeech/ASR/zapformer2/decode.py | 1089 +++++++++ .../ASR/zapformer2/decode_gigaspeech.py | 1 + .../ASR/zapformer2/decode_stream.py | 1 + egs/librispeech/ASR/zapformer2/decoder.py | 1 + .../ASR/zapformer2/encoder_interface.py | 1 + .../ASR/zapformer2/export-onnx-ctc.py | 1 + .../zapformer2/export-onnx-streaming-ctc.py | 1 + .../ASR/zapformer2/export-onnx-streaming.py | 1 + egs/librispeech/ASR/zapformer2/export-onnx.py | 1 + egs/librispeech/ASR/zapformer2/export.py | 1 + egs/librispeech/ASR/zapformer2/finetune.py | 1 + .../ASR/zapformer2/generate_averaged_model.py | 1 + .../ASR/zapformer2/jit_pretrained.py | 1 + .../ASR/zapformer2/jit_pretrained_ctc.py | 1 + .../zapformer2/jit_pretrained_streaming.py | 1 + egs/librispeech/ASR/zapformer2/joiner.py | 1 + .../ASR/zapformer2/label_smoothing.py | 1 + egs/librispeech/ASR/zapformer2/model.py | 630 +++++ egs/librispeech/ASR/zapformer2/my_profile.py | 1 + egs/librispeech/ASR/zapformer2/onnx_check.py | 1 + egs/librispeech/ASR/zapformer2/onnx_decode.py | 1 + .../onnx_pretrained-streaming-ctc.py | 1 + .../zapformer2/onnx_pretrained-streaming.py | 1 + .../ASR/zapformer2/onnx_pretrained.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_H.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zapformer2/onnx_pretrained_ctc_HLG.py | 1 + .../onnx_pretrained_ctc_HLG_streaming.py | 1 + egs/librispeech/ASR/zapformer2/optim.py | 1 + egs/librispeech/ASR/zapformer2/pretrained.py | 1 + .../ASR/zapformer2/pretrained_ctc.py | 1 + .../relative_position_attention_bwd_k_2.py | 321 +++ .../relative_position_attention_bwd_pos_2.py | 321 +++ .../relative_position_attention_bwd_q_2.py | 332 +++ .../relative_position_attention_fwd_2.py | 302 +++ ...ive_position_attention_module_optimized.py | 118 + egs/librispeech/ASR/zapformer2/scaling.py | 1 + .../ASR/zapformer2/scaling_converter.py | 1 + .../ASR/zapformer2/speech_recognition.py | 229 ++ .../ASR/zapformer2/streaming_beam_search.py | 1 + .../ASR/zapformer2/streaming_decode.py | 1 + egs/librispeech/ASR/zapformer2/subsampling.py | 1 + .../ASR/zapformer2/test_scaling.py | 1 + .../ASR/zapformer2/test_subsampling.py | 1 + egs/librispeech/ASR/zapformer2/train.py | 1678 +++++++++++++ egs/librispeech/ASR/zapformer2/zipformer.py | 2066 +++++++++++++++++ 52 files changed, 7581 insertions(+) create mode 100644 egs/librispeech/ASR/zapformer2/.gitignore create mode 100755 egs/librispeech/ASR/zapformer2/asr_datamodule.py create mode 120000 egs/librispeech/ASR/zapformer2/attention_decoder.py create mode 120000 egs/librispeech/ASR/zapformer2/beam_search.py create mode 120000 egs/librispeech/ASR/zapformer2/ctc_decode.py create mode 100755 egs/librispeech/ASR/zapformer2/decode.py create mode 120000 egs/librispeech/ASR/zapformer2/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/zapformer2/decode_stream.py create mode 120000 egs/librispeech/ASR/zapformer2/decoder.py create mode 120000 egs/librispeech/ASR/zapformer2/encoder_interface.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/export-onnx.py create mode 120000 egs/librispeech/ASR/zapformer2/export.py create mode 120000 egs/librispeech/ASR/zapformer2/finetune.py create mode 120000 egs/librispeech/ASR/zapformer2/generate_averaged_model.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/joiner.py create mode 120000 egs/librispeech/ASR/zapformer2/label_smoothing.py create mode 100755 egs/librispeech/ASR/zapformer2/model.py create mode 120000 egs/librispeech/ASR/zapformer2/my_profile.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_check.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_decode.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py create mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py create mode 120000 egs/librispeech/ASR/zapformer2/optim.py create mode 120000 egs/librispeech/ASR/zapformer2/pretrained.py create mode 120000 egs/librispeech/ASR/zapformer2/pretrained_ctc.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py create mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py create mode 120000 egs/librispeech/ASR/zapformer2/scaling.py create mode 120000 egs/librispeech/ASR/zapformer2/scaling_converter.py create mode 100755 egs/librispeech/ASR/zapformer2/speech_recognition.py create mode 120000 egs/librispeech/ASR/zapformer2/streaming_beam_search.py create mode 120000 egs/librispeech/ASR/zapformer2/streaming_decode.py create mode 120000 egs/librispeech/ASR/zapformer2/subsampling.py create mode 120000 egs/librispeech/ASR/zapformer2/test_scaling.py create mode 120000 egs/librispeech/ASR/zapformer2/test_subsampling.py create mode 100755 egs/librispeech/ASR/zapformer2/train.py create mode 100644 egs/librispeech/ASR/zapformer2/zipformer.py diff --git a/egs/librispeech/ASR/zapformer2/.gitignore b/egs/librispeech/ASR/zapformer2/.gitignore new file mode 100644 index 0000000000..e47ac15828 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/librispeech/ASR/zapformer2/asr_datamodule.py b/egs/librispeech/ASR/zapformer2/asr_datamodule.py new file mode 100755 index 0000000000..4db6e101fb --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/asr_datamodule.py @@ -0,0 +1,454 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +# This K2SpeechRecognitionDataset is a modified version of one from +# lhotse.dataset, modified to, in training mode, to return a batch that has 3 +# different copies of the same data with the last two having different Musan +# augmentations and the first having none; and also include the key "num_copies" +# in the batch which would be 1 for the validation data (no Musan) and 3 for the +# training data with musan. +from speech_recognition import K2SpeechRecognitionDataset +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=[], + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zapformer2/attention_decoder.py b/egs/librispeech/ASR/zapformer2/attention_decoder.py new file mode 120000 index 0000000000..830180a0cd --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/attention_decoder.py @@ -0,0 +1 @@ +../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/beam_search.py b/egs/librispeech/ASR/zapformer2/beam_search.py new file mode 120000 index 0000000000..8554e44ccf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/ctc_decode.py b/egs/librispeech/ASR/zapformer2/ctc_decode.py new file mode 120000 index 0000000000..a78e5c1df0 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/ctc_decode.py @@ -0,0 +1 @@ +../zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode.py b/egs/librispeech/ASR/zapformer2/decode.py new file mode 100755 index 0000000000..221f01297b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode.py @@ -0,0 +1,1089 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2] + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py new file mode 120000 index 0000000000..63b0ef617b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py @@ -0,0 +1 @@ +../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode_stream.py b/egs/librispeech/ASR/zapformer2/decode_stream.py new file mode 120000 index 0000000000..4e59d04a12 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decode_stream.py @@ -0,0 +1 @@ +../zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decoder.py b/egs/librispeech/ASR/zapformer2/decoder.py new file mode 120000 index 0000000000..cab465d2b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/decoder.py @@ -0,0 +1 @@ +../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/encoder_interface.py b/egs/librispeech/ASR/zapformer2/encoder_interface.py new file mode 120000 index 0000000000..aa5d0217a8 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py new file mode 120000 index 0000000000..dc14e93e75 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py new file mode 120000 index 0000000000..3baa2b673c --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py new file mode 120000 index 0000000000..d18cb9a9a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py @@ -0,0 +1 @@ +../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx.py b/egs/librispeech/ASR/zapformer2/export-onnx.py new file mode 120000 index 0000000000..f343cf7027 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export-onnx.py @@ -0,0 +1 @@ +../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export.py b/egs/librispeech/ASR/zapformer2/export.py new file mode 120000 index 0000000000..1a126ab695 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/export.py @@ -0,0 +1 @@ +../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/finetune.py b/egs/librispeech/ASR/zapformer2/finetune.py new file mode 120000 index 0000000000..0e9e7989b9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/finetune.py @@ -0,0 +1 @@ +../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/generate_averaged_model.py b/egs/librispeech/ASR/zapformer2/generate_averaged_model.py new file mode 120000 index 0000000000..b65513a058 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/generate_averaged_model.py @@ -0,0 +1 @@ +../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained.py b/egs/librispeech/ASR/zapformer2/jit_pretrained.py new file mode 120000 index 0000000000..5d45825206 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py new file mode 120000 index 0000000000..43aeb684bf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py new file mode 120000 index 0000000000..8e5e6f9812 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/joiner.py b/egs/librispeech/ASR/zapformer2/joiner.py new file mode 120000 index 0000000000..444cb5f150 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/joiner.py @@ -0,0 +1 @@ +../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/label_smoothing.py b/egs/librispeech/ASR/zapformer2/label_smoothing.py new file mode 120000 index 0000000000..3690afff9d --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/label_smoothing.py @@ -0,0 +1 @@ +../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/model.py b/egs/librispeech/ASR/zapformer2/model.py new file mode 100755 index 0000000000..278e498032 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/model.py @@ -0,0 +1,630 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from encoder_interface import EncoderInterface +from scaling import ScaledLinear, convert_num_channels, PredictLoss +from icefall.utils import add_sos, make_pad_mask, time_warp + + +class AsrModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + use_attention_decoder: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + use_attention_decoder: + Whether use attention-decoder head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.predict_loss = PredictLoss(encoder_dim) + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.1, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.1, + ) + + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), + nn.LogSoftmax(dim=-1), + ) + + self.use_attention_decoder = use_attention_decoder + if use_attention_decoder: + self.attention_decoder = attention_decoder + else: + assert attention_decoder is None + + self.reconstruction_proj = ScaledLinear( + encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) + + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, sd_prob: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. + + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) + + x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + + src_key_padding_mask = make_pad_mask(x_lens) # (N, T) + specaug_mask = specaug_mask[:, ::2] + assert abs(specaug_mask.shape[1] - src_key_padding_mask.shape[1]) < 10 + specaug_mask = convert_num_channels(specaug_mask, src_key_padding_mask.shape[1]) # pad or truncate. (N, T) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, + aux_loss_scale=aux_loss_scale, + sd_prob=0.0) + + predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + + return encoder_out, encoder_out_lens, predict_loss + + + def compute_predict_loss(self, + encoder_out: Tensor, + src_key_padding_mask: Optional[Tensor], + specaug_mask: Optional[Tensor]) -> Tensor: + if src_key_padding_mask is not None and specaug_mask is not None: + mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) + elif src_key_padding_mask is not None: + mask = src_key_padding_mask.t().logical_not() + elif specaug_mask is not None: + mask = specaug_mask.t().logical_not() + else: + mask = None + return self.predict_loss(encoder_out, mask) + + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). This activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + return ctc_loss + + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss, with consistency regularization loss if we are in training mode. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.amp.autocast('cuda', enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.amp.autocast('cuda', enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + spec_augment: Optional[nn.Module] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + num_copies: int = 1, + aux_loss_scale: float = 0.0, + sd_prob: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + spec_augment: + The SpecAugment instance, or similar/compatible object, that masks + log-mel features. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. Used only for + time-warping, if num_copies > 1. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if num_copies > 1, corresponds to training mode. + num_copies: + the number of copies of the same data that are in the batch, e.g. 1, 2 + or 3; affects CRCTC, spec-augment, etc. + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probabilty but the probabibilty + of taking the output of a randomly chosen layer, instead of the last layer. + + Returns: + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + device = x.device + + if num_copies > 1: + assert num_copies == 3 # for now. + # will do SpecAugment or similar. + assert spec_augment is not None and getattr(spec_augment, 'time_warp_factor', -1) < 0 + + (batch_size, seq_len, num_channels) = x.shape + B = batch_size // num_copies + x = x.reshape(num_copies, B, seq_len, num_channels) + + do_time_warp = True + if do_time_warp: + # Apply time warping. First append the copies on the channel + # dimension so all copies get the exact same time-warping. + x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) + + assert supervision_segments is not None + with torch.amp.autocast('cuda', enabled=False): + x = time_warp( + x.to(torch.float), + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments[:B], + ) + x = x.reshape(B, seq_len, num_copies, num_channels) + x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + + # x_no_specaug is several repeats of the 1st copy of the data, which + # is the one not augmented with Musan. But it does have time + # warping and mel warping. + x_no_specaug = x[0:1].repeat(num_copies - 1, 1, 1, 1).reshape( + B * (num_copies - 1), seq_len, num_channels) + + + # Independently apply frequency masking and time masking to all but the first + # copy of the data. + x = spec_augment(x[1:].reshape(-1, seq_len, num_channels)) + + x_lens = x_lens[:B*(num_copies-1)] + y = y[:B*(num_copies-1)] + else: + x_no_specaug = x + + + # Compute encoder outputs + encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale, + sd_prob=sd_prob) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + targets = y.values + if not self.training: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) + + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, + encoder_out_lens) + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss + + + def forward_reconstruction_loss(self, + log_mels: Tensor, + encoder_out: Tensor, + encoder_out_lens: Tensor): + """ + Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If + use_cr_ctc then we swap the first and second halves of the batch. + + Args: + log_mels: log-mel features of shape (batch_size, T, num_mels) + encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) + """ + batch_size = log_mels.shape[0] + num_mels = log_mels.shape[2] + + + def gauss_norm(x): + # normalize by gaussianizing on each dimension + values, indexes = x.sort(dim=1) # sort on seq dim + N = max(2, x.shape[1]) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[1], device=x.device, dtype=torch.float) + norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data + norm_rank = norm_rank.reshape(1, -1, 1) + norm_rank = norm_rank.repeat(x.shape[0], 1, x.shape[2]) + x_norm = torch.empty_like(x) + x_norm.scatter_(dim=1, index=indexes, src=norm_rank) + return x_norm + + log_mels = gauss_norm(log_mels) + + pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) + T_embed = pred_mels.shape[1] + pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) + + excess_frames = log_mels.shape[1] - pred_mels.shape[1] + assert 4 < excess_frames < 10 # should be around 7 or 8 I believe. + + T = pred_mels.shape[1] + offset = 3 # i found excess_frames = 5 one time. + log_mels = log_mels[:, offset:offset+T] + + lens = encoder_out_lens * 4 + pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions + assert pad_mask.shape == (batch_size, T) + pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position + # padd_mask: (batch_size, T, 1) + + + # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly + # helps to down-weight the effect of very silent silences. + #loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, + # reduction='none', beta=1.0) + # this way of applying the padding mask is not really ideal in terms of normalization, + # it will cause us to under-normalize a bit. + diff = log_mels * pad_mask - pred_mels * pad_mask + + loss = (diff ** 2) + + # removing the masking logic since we now use the no-specaug reference sequence. + ## masking. if it's different from the next item on both the frequency dim + ## and the time dim, it means we are in neither a time masked nor a frequency masked + ## position. + #mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), + # log_mels != torch.roll(log_mels, 1, dims=1)) + #loss = loss * mask.to(loss.dtype) + + loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. + return loss diff --git a/egs/librispeech/ASR/zapformer2/my_profile.py b/egs/librispeech/ASR/zapformer2/my_profile.py new file mode 120000 index 0000000000..76e48b756b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/my_profile.py @@ -0,0 +1 @@ +../zipformer/my_profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_check.py b/egs/librispeech/ASR/zapformer2/onnx_check.py new file mode 120000 index 0000000000..7293c70d46 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_check.py @@ -0,0 +1 @@ +../zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_decode.py b/egs/librispeech/ASR/zapformer2/onnx_decode.py new file mode 120000 index 0000000000..9e3faa5e01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_decode.py @@ -0,0 +1 @@ +../zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py new file mode 120000 index 0000000000..f8abb9daa5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py new file mode 120000 index 0000000000..11b846322e --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained.py new file mode 120000 index 0000000000..a085def837 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py new file mode 120000 index 0000000000..0c082a204f --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py new file mode 120000 index 0000000000..68102c7374 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py new file mode 120000 index 0000000000..8314b4efdf --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py new file mode 120000 index 0000000000..7a637a1c01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py new file mode 120000 index 0000000000..a5b04b3f8b --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/optim.py b/egs/librispeech/ASR/zapformer2/optim.py new file mode 120000 index 0000000000..207eecfcda --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained.py b/egs/librispeech/ASR/zapformer2/pretrained.py new file mode 120000 index 0000000000..70ad71ffc6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/pretrained.py @@ -0,0 +1 @@ +../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/pretrained_ctc.py new file mode 120000 index 0000000000..fb9bdf1fa2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/pretrained_ctc.py @@ -0,0 +1 @@ +../zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py new file mode 100755 index 0000000000..aa85d1fff7 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_k_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in scores_grad + BLOCK_N: tl.constexpr, # block size in q + BLOCK_C: tl.constexpr, # block size for seq_q + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_k, BLOCK_M) + num_pid_n = tl.cdiv(channels, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(BLOCK_M == 1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), for k, seq_k + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,), for j, channel + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n_mask = offs_n[:, None] < channels + + # (BLOCK_C,), for i, seq_q + offs_c = tl.arange(0, BLOCK_C) + + q_base = q_ptr + batch * stride_qb + head * stride_qh + offs_n[:, None] * stride_qc + k_base = k_ptr + batch * stride_kb + head * stride_kh + pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc + scores_grad_base = ( + scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sk + ) + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + scores_grad_mask = (offs_m < seq_k) & (c_idx[None, :] < seq_q) + + # (BLOCK_N, BLOCK_C), or (J, I) + q_mask = offs_n_mask & (c_idx[None, :] < seq_q) + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + rel_idx = c_idx[None, :] - offs_m + max_seq_len - 1 + + # (BLOCK_M, BLOCK_N, BLOCK_C), or (K, J, I), or (BLOCK_N, BLOCK_C) + pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask + + scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sq + q_ptrs = q_base + c_idx[None, :] * stride_qs + + # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + # (BLOCK_N, BLOCK_C), or (J, I) + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C), or (J, I) + pos_ptrs = pos_base + rel_idx * stride_ps + + pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) + + # scores_grad_chunk (1, BLOCK_C), or (K, I) + # q_chunk (BLOCK_N, BLOCK_C), or (J, I) + # pos_chunk (BLOCK_N, BLOCK_C), or (J, I) + qp = q_chunk * pos_chunk + + acc += tl.sum(scores_grad_chunk * qp, axis=1) + + k_ptrs = k_base + offs_m * stride_ks + offs_n * stride_kc + k_mask = (offs_m < seq_k) & (offs_n < channels) + tl.store(k_ptrs, acc, mask=k_mask) + + +def relative_position_attention_bwd_k(scores_grad, q, pos): + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + assert q.is_contiguous() + assert pos.is_contiguous() + + assert scores_grad.ndim == q.ndim == 4, (scores_grad.shape, q.shape) + + assert pos.ndim == 3, pos.shape + b, h, seq_q, seq_k = scores_grad.shape + + assert q.shape[0] == b, q.shape + assert q.shape[1] == h, q.shape + assert q.shape[2] == seq_q, q.shape + + c = q.shape[3] + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert scores_grad.device == q.device == pos.device, ( + scores_grad.device, + q.device, + pos.device, + ) + + k = torch.empty(b, h, seq_k, c, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_k, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_k_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return k + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_k(scores_grad, q, pos), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + pos_copy = pos.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("k.grad", k.grad.shape, k.grad.sum()) + + scores_grad = scores0.grad.clone() + k_grad = relative_position_attention_bwd_k(scores_grad, q_copy, pos_copy) + + print(k_grad.shape, k_grad.sum()) + print((k.grad - k_grad).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py new file mode 100755 index 0000000000..93d1f09dc3 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 16, + "BLOCK_C": 16, + "GROUP_SIZE_M": 4, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_pos_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in q + BLOCK_N: tl.constexpr, # block size in k + BLOCK_C: tl.constexpr, # block size for channel + GROUP_SIZE_M: tl.constexpr, # size for grouped block, not used +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_n = tl.cdiv(seq_k, BLOCK_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(BLOCK_M == 1) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,) + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, 1) + rel_idx = offs_m - offs_n[:, None] + max_seq_len - 1 + + q_base = q_ptr + batch * stride_qb + head * stride_qh + k_base = k_ptr + batch * stride_kb + head * stride_kh + pos_base = pos_ptr + head * stride_ph + + scores_grad_base = scores_grad_ptr + batch * stride_sb + head * stride_sh + scores_grad_ptrs = ( + scores_grad_base + offs_m * stride_sq + offs_n[:, None] * stride_sk + ) + + # (BLOCK_N, 1) + scores_grad_mask = (offs_m < seq_q) & (offs_n[:, None] < seq_k) + + # (BLOCK_N, 1) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (1, BLOCK_C) + q_mask = (offs_m < seq_q) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C), or (K, J) + k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + pos_mask = ( + (rel_idx >= 0) + & (rel_idx < 2 * max_seq_len - 1) + & (c_idx[None, :] < channels) + ) + + q_ptrs = q_base + offs_m * stride_qs + c_idx[None, :] * stride_qc + k_ptrs = k_base + offs_n[:, None] * stride_ks + c_idx[None, :] * stride_kc + + # (1, BLOCK_C) + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs = pos_base + rel_idx * stride_ps + c_idx[None, :] * stride_pc + + # q_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # scores_grad_chunk (BLOCK_N, 1) + # + # pos_chunk: (BLOCK_N, BLOCK_C) + qk = q_chunk * k_chunk + pos_chunk = scores_grad_chunk * qk + + tl.atomic_add(pos_ptrs, pos_chunk, mask=pos_mask) + + +def relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len): + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + + assert q.is_contiguous() + assert k.is_contiguous() + + assert scores_grad.ndim == q.ndim == k.ndim == 4, ( + scores_grad.shape, + q.shape, + k.shape, + ) + b, h, seq_q, seq_k = scores_grad.shape + c = q.shape[3] + + assert k.shape[0] == b, k.shape + assert k.shape[1] == h, k.shape + assert k.shape[2] == seq_k, k.shape + assert k.shape[3] == c, k.shape + + assert q.shape[0] == b, q.shape + assert q.shape[1] == h, q.shape + assert q.shape[2] == seq_q, q.shape + + assert scores_grad.device == q.device == k.device, ( + scores_grad.device, + q.device, + k.device, + ) + + pos = torch.zeros(h, 2 * max_seq_len - 1, c, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_pos_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return pos + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + k_copy = k.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("pos.grad", pos.grad.shape, pos.grad.sum()) + + pos_grad = relative_position_attention_bwd_pos( + scores0.grad, q_copy, k_copy, max_seq_len + ) + + print(pos_grad.shape, pos_grad.sum()) + print((pos.grad - pos_grad).abs().max()) + + +def main(): + # test_benchmark() + test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py new file mode 100755 index 0000000000..5a9ececf0c --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_bwd_q_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_grad_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in scores_grad + BLOCK_N: tl.constexpr, # block size in channels + BLOCK_C: tl.constexpr, # block size for seq_k + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_q, BLOCK_M) + num_pid_n = tl.cdiv(channels, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(BLOCK_M == 1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), we should always set BLOCK_M to 1 + offs_m = pid_m * BLOCK_M + + # (BLOCK_N,) for channels + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,), for seq_k + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, 1) + offs_n_mask = offs_n[:, None] < channels + + q_base = q_ptr + batch * stride_qb + head * stride_qh + k_base = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_kc + pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc + scores_grad_base = ( + scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sq + ) + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, seq_k, BLOCK_C): + c_idx = c + offs_c + + # (1, BLOCK_C) + rel_idx = offs_m - c_idx[None, :] + max_seq_len - 1 + + # (1, BLOCK_C) + scores_grad_mask = (offs_m < seq_q) & (c_idx[None, :] < seq_k) + + # (BLOCK_N, BLOCK_C) + k_mask = offs_n_mask & (c_idx[None, :] < seq_k) + + # (BLOCK_N, BLOCK_C) + pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask + + scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sk + k_ptrs = k_base + c_idx[None, :] * stride_ks + + # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) + scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs = pos_base + rel_idx * stride_ps + + pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) + + # scores_grad_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # pos_chunk (BLOCK_N, BLOCK_C) + + # kp: (BLOCK_N, BLOCK_C) + kp = k_chunk * pos_chunk + + acc += tl.sum(scores_grad_chunk * kp, axis=1) + + q_ptrs = q_base + offs_m * stride_qs + offs_n * stride_qc + q_mask = (offs_m < seq_q) & (offs_n < channels) + tl.store(q_ptrs, acc, mask=q_mask) + + +def relative_position_attention_bwd_q(scores_grad, k, pos): + """ + Args: + scores_grad: (b, h, seq_q, seq_k) + k: (b, h, seq_k, channels) + pos: (h, 2*max_seq_len-1, channels) + Returns: + grad of q: (b, h, seq_q, channels) + """ + if not scores_grad.is_contiguous(): + scores_grad = scores_grad.contiguous() + + assert scores_grad.is_contiguous(), ( + scores_grad.shape, + scores_grad.stride(0), + scores_grad.stride(1), + scores_grad.stride(2), + scores_grad.stride(3), + ) + assert k.is_contiguous() + assert pos.is_contiguous() + + assert scores_grad.ndim == k.ndim == 4, (scores_grad.shape, k.shape) + assert pos.ndim == 3, pos.shape + b, h, seq_q, seq_k = scores_grad.shape + + c = k.shape[3] + + assert k.shape[0] == b, (k.shape, scores_grad.shape) + assert k.shape[1] == h, (k.shape, scores_grad.shape) + assert k.shape[2] == seq_k, (k.shape, scores_grad.shape) + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert scores_grad.device == k.device == pos.device, ( + scores_grad.device, + k.device, + pos.device, + ) + + q = torch.empty(b, h, seq_q, c, device=k.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_bwd_q_kernel[grid]( + q, k, pos, scores_grad, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), + ) + # fmt: on + return q + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton"], + line_names=["Triton"], + styles=[("green", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + max_seq_len = seq_q + + k = torch.randn(b, h, seq_k, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_bwd_q(scores_grad, k, pos), + quantiles=quantiles, + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 2 + seq_q = 250 + seq_k = 250 + c = 1025 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + k_copy = k.clone() + pos_copy = pos.clone() + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores0.retain_grad() + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + print("score0.grad", scores0.grad.shape, scores0.grad.sum()) + print("q.grad", q.grad.shape, q.grad.sum()) + + scores_grad = scores0.grad.clone() + q_grad = relative_position_attention_bwd_q(scores_grad, k_copy, pos_copy) + print(q_grad.shape, q_grad.sum()) + print((q.grad - q_grad).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py new file mode 100755 index 0000000000..e6ea552035 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +import triton.language as tl +import triton +import torch + + +def get_autotune_config(): + configs = [] + configs.append( + triton.Config( + { + "BLOCK_M": 1, + "BLOCK_N": 32, + "BLOCK_C": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=2, + ) + ) + return configs + + +@triton.autotune( + configs=get_autotune_config()[-2:], + key=["seq_q", "seq_k", "channels", "max_seq_len"], +) +@triton.jit +def relative_position_attention_fwd_kernel( + # fmt: off + q_ptr, # (batches, head, seq_q, channel) + k_ptr, # (batches, head, seq_k, channel) + pos_ptr, # (head, 2*max_seq_len-1, channel) + scores_ptr, # (batches, head, seq_q, seq_k) + B, H, seq_q, seq_k, channels, max_seq_len, # shape + stride_qb, stride_qh, stride_qs, stride_qc, # stride for q + stride_kb, stride_kh, stride_ks, stride_kc, # stride for k + stride_ph, stride_ps, stride_pc, # stride for pos + stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores + BLOCK_M: tl.constexpr, # block size in q + BLOCK_N: tl.constexpr, # block size in k + BLOCK_C: tl.constexpr, # block size for channel + GROUP_SIZE_M: tl.constexpr, # size for grouped block +): + # fmt: on + pid = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + head = pid_bh % H + batch = pid_bh // H + + num_pid_m = tl.cdiv(seq_q, BLOCK_M) + num_pid_n = tl.cdiv(seq_k, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + tl.assume(stride_qb > 0) + tl.assume(stride_qh > 0) + tl.assume(stride_qs > 0) + tl.assume(stride_qc > 0) + + tl.assume(stride_kb > 0) + tl.assume(stride_kh > 0) + tl.assume(stride_ks > 0) + tl.assume(stride_kc > 0) + + tl.assume(stride_ph > 0) + tl.assume(stride_ps > 0) + tl.assume(stride_pc > 0) + + tl.assume(stride_sb > 0) + tl.assume(stride_sh > 0) + tl.assume(stride_sq > 0) + tl.assume(stride_sk > 0) + + # (BLOCK_M,), we should always set BLOCK_M to 1 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # (BLOCK_N,) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # (BLOCK_C,) + offs_c = tl.arange(0, BLOCK_C) + + # (BLOCK_N, ) + rel_idx = offs_m - offs_n + max_seq_len - 1 + + # (BLOCK_N, 1) + rel_idx_mask = (rel_idx[:, None] >= 0) & (rel_idx[:, None] < 2 * max_seq_len - 1) + + q_ptrs = q_ptr + batch * stride_qb + head * stride_qh + offs_m[:, None] * stride_qs + k_ptrs = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_ks + + pos_ptrs = pos_ptr + head * stride_ph + rel_idx[:, None] * stride_ps + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for c in range(0, channels, BLOCK_C): + c_idx = c + offs_c + + # (BLOCK_M, BLOCK_C) + q_mask = (offs_m[:, None] < seq_q) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) + + # (BLOCK_N, BLOCK_C) + pos_mask = rel_idx_mask & (c_idx[None, :] < channels) + + q_ptrs0 = q_ptrs + c_idx[None, :] * stride_qc + k_ptrs0 = k_ptrs + c_idx[None, :] * stride_kc + + # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) + q_chunk = tl.load(q_ptrs0, mask=q_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + k_chunk = tl.load(k_ptrs0, mask=k_mask, other=0.0) + + # (BLOCK_N, BLOCK_C) + pos_ptrs0 = pos_ptrs + c_idx[None, :] * stride_pc + + pos_chunk = tl.load(pos_ptrs0, mask=pos_mask, other=0.0) + + # q_chunk (1, BLOCK_C) + # k_chunk (BLOCK_N, BLOCK_C) + # pos_chunk (BLOCK_N, BLOCK_C) + + acc += tl.sum(q_chunk * (k_chunk * pos_chunk), axis=1) + + scores_ptrs = ( + scores_ptr + + batch * stride_sb + + head * stride_sh + + offs_m * stride_sq + + offs_n * stride_sk + ) + scores_mask = (offs_m < seq_q) & (offs_n < seq_k) + + tl.store(scores_ptrs, acc, mask=scores_mask) + + +def relative_position_attention_fwd(q, k, pos): + assert q.is_contiguous() + assert k.is_contiguous() + assert pos.is_contiguous() + + assert q.ndim == k.ndim == 4, (q.shape, k.shape) + assert pos.ndim == 3, pos.shape + b, h, seq_q, c = q.shape + assert k.shape[0] == b, k.shape + assert k.shape[1] == h, k.shape + assert k.shape[3] == c, k.shape + + seq_k = k.shape[2] + + assert pos.shape[0] == h, pos.shape + pos.shape[2] == c, pos.shape + + max_seq_len = (pos.shape[1] + 1) // 2 + + assert q.device == k.device == pos.device, ( + q.device, + k.device, + pos.device, + ) + + scores = torch.empty(b, h, seq_q, seq_k, device=q.device) + + grid = lambda META: ( + triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), + b * h, + ) + + # fmt:off + relative_position_attention_fwd_kernel[grid]( + q, k, pos, scores, + b, h, seq_q, seq_k, c, max_seq_len, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + pos.stride(0), pos.stride(1), pos.stride(2), + scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3), + ) + # fmt: on + return scores + + +def relative_position_attention_fwd_torch(q, k, pos): + # this function consume a lot of memory, may OOM + max_seq_len = (pos.shape[1] + 1) // 2 + seq_q = q.shape[2] + seq_k = k.shape[2] + + q = q.unsqueeze(3) + k = k.unsqueeze(2) + + i = torch.arange(seq_q, device=q.device).unsqueeze(1) + j = torch.arange(seq_k, device=q.device).unsqueeze(0) + rel = (i - j) + max_seq_len - 1 + rel = rel.clamp(0, pos.shape[1] - 1) + pos_indexed = pos[:, rel].unsqueeze(0) + + # q: (b, h, seq_q, 1, c) + # q: (b, h, 1, seq_k, c) + # pos: (1, h, seq_q, seq_k, c) + scores = (q * k * pos_indexed).sum(dim=-1) + return scores + + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=[ + "b", + "h", + "seq_q", + "seq_k", + "c", + ], # Argument names to use as an x-axis for the plot + x_vals=[ + (b, h, seq, seq, c) + for b in [1, 2, 3] + # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation + for h in [2, 4] + for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + for c in [128, 256, 512] + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + line_vals=["triton", "torch"], + line_names=["Triton", "Torch"], + styles=[("green", "-"), ("blue", "-")], + ylabel="time (ms)", # Label name for the y-axis + plot_name="matmul-performance with pos", + args=dict(), + ) +) + + +@triton.testing.perf_report(configs) +def benchmark(b, h, seq_q, seq_k, c, provider): + device = torch.device("cuda", 0) + + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_fwd_torch(q, k, pos), + quantiles=quantiles, + ) + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: relative_position_attention_fwd(q, k, pos), quantiles=quantiles + ) + return ms, max_ms, min_ms + + +def test_benchmark(): + benchmark.run(show_plots=False, print_data=True) + + +def test_correctness(): + device = torch.device("cuda", 0) + b = 2 + h = 8 + seq_q = 400 + seq_k = 400 + c = 1024 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + scores0 = relative_position_attention_fwd_torch(q, k, pos) + scores1 = relative_position_attention_fwd(q, k, pos) + print(scores0.shape, scores0.sum()) + print(scores1.shape, scores1.sum()) + print((scores0 - scores1).abs().max()) + + +def main(): + test_benchmark() + # test_correctness() + + +if __name__ == "__main__": + torch.manual_seed(20250812) + main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py new file mode 100755 index 0000000000..21640764ba --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +import torch + +from relative_position_attention_fwd_2 import ( + relative_position_attention_fwd, + relative_position_attention_fwd_torch, +) + +from relative_position_attention_bwd_q_2 import relative_position_attention_bwd_q +from relative_position_attention_bwd_k_2 import relative_position_attention_bwd_k +from relative_position_attention_bwd_pos_2 import relative_position_attention_bwd_pos + + +class RelativePositionAttentionFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, pos): + """ + Args: + q: (batch, head, seq_q, channel) + k: (batch, head, seq_k, channel) + pos: (head, 2*max_seq_len-1, channel) + Returns: + scores: (batch, head, seq_q, seq_k) + """ + ctx.save_for_backward(q, k, pos) + return relative_position_attention_fwd(q, k, pos) + + @staticmethod + def backward(ctx, scores_grad): + q, k, pos = ctx.saved_tensors + q_grad = None + k_grad = None + pos_grad = None + + if ctx.needs_input_grad[0]: + q_grad = relative_position_attention_bwd_q(scores_grad, k, pos) + + if ctx.needs_input_grad[1]: + k_grad = relative_position_attention_bwd_k(scores_grad, q, pos) + + if ctx.needs_input_grad[2]: + max_seq_len = (pos.shape[1] + 1) // 2 + pos_grad = relative_position_attention_bwd_pos( + scores_grad, q, k, max_seq_len + ) + + return q_grad, k_grad, pos_grad + + +class RelativePositionAttentionModule(torch.nn.Module): + def forward( + self, q: torch.Tensor, k: torch.Tensor, pos: torch.Tensor + ) -> torch.Tensor: + """ + Args: + q: (batch, head, seq_q, channel) + k: (batch, head, seq_k, channel) + pos: (head, 2*max_seq_len-1, channel) + Returns: + scores: (batch, head, seq_q, seq_k) + """ + return RelativePositionAttentionFunction.apply(q, k, pos) + + +def _test(): + torch.manual_seed(20250820) + device = torch.device("cuda", 0) + b = 4 + h = 2 + seq_q = 100 + seq_k = 100 + c = 300 + max_seq_len = seq_q + + q = torch.randn(b, h, seq_q, c, device=device) + k = torch.randn(b, h, seq_k, c, device=device) + pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) + + q_copy = q.clone() + k_copy = k.clone() + pos_copy = pos.clone() + + q.requires_grad_(True) + k.requires_grad_(True) + pos.requires_grad_(True) + + scores0 = relative_position_attention_fwd_torch(q, k, pos) + + scale = torch.rand_like(scores0) + + s0 = (scale * scores0).sum() + s0.backward() + + q_copy.requires_grad_(True) + k_copy.requires_grad_(True) + pos_copy.requires_grad_(True) + + scores1 = RelativePositionAttentionModule()(q_copy, k_copy, pos_copy) + + s1 = (scale * scores1).sum() + s1.backward() + + print((s0 - s1).max().abs()) + print((q.grad - q_copy.grad).max().abs()) + print((k.grad - k_copy.grad).max().abs()) + print((pos.grad - pos_copy.grad).max().abs()) + """ + tensor(0.0005, device='cuda:0', grad_fn=) + tensor(7.6294e-06, device='cuda:0') + tensor(5.7220e-06, device='cuda:0') + tensor(3.4332e-05, device='cuda:0') + """ + + +if __name__ == "__main__": + _test() + pass diff --git a/egs/librispeech/ASR/zapformer2/scaling.py b/egs/librispeech/ASR/zapformer2/scaling.py new file mode 120000 index 0000000000..58e4b0a0fe --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/scaling.py @@ -0,0 +1 @@ +../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/scaling_converter.py b/egs/librispeech/ASR/zapformer2/scaling_converter.py new file mode 120000 index 0000000000..bc7c7b5e37 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/scaling_converter.py @@ -0,0 +1 @@ +../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/speech_recognition.py b/egs/librispeech/ASR/zapformer2/speech_recognition.py new file mode 100755 index 0000000000..dd069cf3da --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/speech_recognition.py @@ -0,0 +1,229 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class K2SpeechRecognitionDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech recognition task using k2 library. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + if self.cut_transforms: + orig_cuts = cuts + + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts + num_copies = 3 + else: + num_copies = 1 + + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) diff --git a/egs/librispeech/ASR/zapformer2/streaming_beam_search.py b/egs/librispeech/ASR/zapformer2/streaming_beam_search.py new file mode 120000 index 0000000000..97e6e733f2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/streaming_beam_search.py @@ -0,0 +1 @@ +../zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/streaming_decode.py b/egs/librispeech/ASR/zapformer2/streaming_decode.py new file mode 120000 index 0000000000..e31da07d01 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/streaming_decode.py @@ -0,0 +1 @@ +../zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/subsampling.py b/egs/librispeech/ASR/zapformer2/subsampling.py new file mode 120000 index 0000000000..d178adc2e5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/subsampling.py @@ -0,0 +1 @@ +../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_scaling.py b/egs/librispeech/ASR/zapformer2/test_scaling.py new file mode 120000 index 0000000000..b776da79a1 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/test_scaling.py @@ -0,0 +1 @@ +../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_subsampling.py b/egs/librispeech/ASR/zapformer2/test_subsampling.py new file mode 120000 index 0000000000..2925ea3c51 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/test_subsampling.py @@ -0,0 +1 @@ +../zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/train.py b/egs/librispeech/ASR/zapformer2/train.py new file mode 100755 index 0000000000..4294e139f6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/train.py @@ -0,0 +1,1678 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Sched3, TransformedAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def get_adjusted_lr_batches(params: AttributeDict) -> float: + # returns an adjusted form of the "lr_batches" parameter used to set the learning + # rate in the Sched3 scheduler. + # We want the final LR to be based on the geometric mean of "how much data we + # have seen" and "how many batches we have seen". + # an easier way to look at it is this: the formula for learning rate depends + # on (cur_batch / lr_batches). if we write this as: + # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches + # then the numerator is a geometric mean of "how many batches we have seen" + # and "how much data we have seen". We can achieve this by setting + # lr_batches = params.lr_batches * (duration_ratio ** -0.5). + duration_ratio = (params.max_duration * params.world_size) / params.ref_duration + lr_batches = params.lr_batches * (duration_ratio ** -0.5) + logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") + return lr_batches + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def lookup(params: AttributeDict, name: str): + """ + Interprets numerical arguments in `params` by taking into account base-dim; + also parses comma-separated lists of integers, turning them into tuples. + If a particular attribute ending in "dim" is not present we look up + the same name but ending in "factor", and multiply the elements by base_dim. + """ + try: + attr = getattr(params, name) + try: + attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints + if len(attr) == 1: + attr = attr[0] + except: + pass # leave attr as it is, e.g. a string. + return attr + except AttributeError as e: + if name[-3:] != "dim": + raise e + try: + attr = getattr(params, name[:-3] + "multiple") + if isinstance(attr, str): + attr = tuple(map(int, attr.split(","))) # tuple of ints + base_dim = params.base_dim + attr = tuple([i * base_dim for i in attr]) + if len(attr) == 1: + attr = attr[0] + else: # assume int. + assert isinstance(attr, (int, float)), (name, attr) + attr = attr * params.base_dim + return attr + except AttributeError as e: + raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") + + + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="3,5,6,6,6,5", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--base-dim", + type=int, + default=64, + help="Dimension that, via multiples, defines the dimensions of the model." + ) + + parser.add_argument( + "--embed-multiple", + type=int, + default=6, + help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", + ) + + parser.add_argument( + "--feedforward-multiple", + type=str, + default="3,3,3,3,3,3", + help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-multiple", + type=str, + default="4,6,9,12,9,6", + help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-multiple", + type=int, + default=8, + help="Factor by which embedding dimension in the decoder model is larger than base-dim.", + ) + + parser.add_argument( + "--joiner-multiple", + type=int, + default=8, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--attention-decoder-multiple", + type=int, + default=8, + help="""Factor by which attention decoder dim is larger than base-dim""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-multiple", + type=int, + default=8, + help="""Determines attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-multiple", + type=int, + default=4, + help="""Factor by which feedforward hidden dim in attention decoder is larger than attention-decoder-dim""" + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=True, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--debug-interval", + type=int, + default=10, + help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for + finding parts of the network that are diverging or not well trained. + """ + ) + + parser.add_argument( + "--dump-debug-interval", + type=int, + default=0, + help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they + are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. + Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file + opened, seeked-in and closed for each scalar that is written). + """ + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=17500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--aux-loss-scale", + type=float, + default=0.05, + help="Scale on auxiliary losses that are defined in the model, such " + "as cosine loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--reconstruction-loss-scale", + type=float, + default=0.005, + help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", + ) + + parser.add_argument( + "--predict-loss-scale", + type=float, + default=0.01, + help="Prediction of random k-means after widest zipformer layer" + ) + + parser.add_argument( + "--stochastic-depth-prob", + type=float, + default=0.1, + help="Probability of using a randomly chosen stack output during training, instead of " + "final output." + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on pruned loss (for transducer) and the reconstruction and prediction + losses. Expressed in terms of the "adjusted batch count", i.e. the + normalized batch count after adjusting for changes in batch size. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=lookup(params, "embed_dim"), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + input_dim=lookup(params, "embed_dim"), + output_downsampling_factor=2, + downsampling_factor=lookup(params, "downsampling_factor"), + num_encoder_layers=lookup(params, "num_encoder_layers"), + encoder_dim=lookup(params, "encoder_dim"), + query_head_dim=lookup(params, "query_head_dim"), + value_head_dim=lookup(params, "value_head_dim"), + num_heads=lookup(params, "num_heads"), + feedforward_multiple=lookup(params, "feedforward_multiple"), + cnn_module_kernel=lookup(params, "cnn_module_kernel"), + dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero + causal=params.causal, + chunk_size=lookup(params, "chunk_size"), + left_context_frames=lookup(params, "left_context_frames"), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "decoder_dim"), + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 + joiner = Joiner( + encoder_dim=lookup(params, "embed_dim") * output_downsampling_factor, + decoder_dim=lookup(params, "decoder_dim"), + joiner_dim=lookup(params, "joiner_dim"), + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "attention_decoder_dim"), + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=lookup(params, "attention_decoder_attention_dim"), + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_multiple * lookup(params, "attention_decoder_attention_dim"), + memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + output_downsampling_factor = 2 + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=output_downsampling_factor * lookup(params, "embed_dim"), + decoder_dim=lookup(params, "decoder_dim"), + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + spec_augment: Optional[nn.Module] = None, + aux_loss_scale: float = 0.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + spec_augment: + The nn.Module instance (or similar object), used for training + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + num_copies = batch["num_copies"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + if num_copies > 1: + assert model.training + # will need the following for time-warping in nn.Module. + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + spec_augment = None # disable spec-aug + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=80, # for specaug + num_copies=num_copies, + aux_loss_scale=aux_loss_scale, + sd_prob=(params.stochastic_depth_prob if is_training else 0.0), + ) + + loss = 0.0 + + adjusted_batch_count = params.batch_idx_train + warm_step = params.warm_step + def warmup_schedule(scale, initial_factor): + # geometric warmup schedules. + warmup_factor = (1. if adjusted_batch_count >= warm_step else + initial_factor + (adjusted_batch_count / warm_step) * (1 - initial_factor)) + return scale * warmup_factor + + if params.use_transducer: + simple_loss_scale = params.simple_loss_scale + pruned_loss_scale = warmup_schedule(1.0, 0.05) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if num_copies > 1: + loss += params.cr_loss_scale * cr_loss + + reconstruction_loss_scale = params.reconstruction_loss_scale + + loss += reconstruction_loss_scale * reconstruction_loss + + if num_copies > 1: + loss += params.predict_loss_scale * predict_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + nframes = (feature_lens // params.subsampling_factor).sum().item() + if num_copies > 1: + nframes = nframes * (num_copies - 1) / num_copies # omit 1st copy + info["frames"] = nframes + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if num_copies > 1: + info["cr_loss"] = cr_loss.detach().cpu().item() + if num_copies > 1: + info["predict_loss"] = predict_loss.detach().cpu().item() + info["recon_loss"] = reconstruction_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + spec_augment: Optional[nn.Module] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + spec_augment: + The SpecAugment or similar instance used for CR-CTC. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def get_scaler_scale(): + if params.use_autocast and scaler._scale is not None: + return scaler._scale.item() + else: + return 1.0 + + def save_bad_model(suffix: str = ""): + if params.debug_interval > 0: + optimizer.write_debug_info(summary_writer=tb_writer) + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = get_scaler_scale() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or + batch_idx % 100 == 0 and cur_grad_scale < 8.0 or + batch_idx % 400 == 0 and cur_grad_scale < 32.0): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = get_scaler_scale() + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + optimizer.write_debug_info(summary_writer=tb_writer) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + + assert params.use_ctc # for now, require CTC, we may remove this requirement later. + + spec_augment = ExpAugment() + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = TransformedAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + debug_interval=params.debug_interval, + ) + + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and False: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + spec_augment=spec_augment, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + spec_augment=spec_augment, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + d = diagnostic.print_diagnostics() + filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" + torch.save(d, filename) + logging.info(f"Saved detailed diagnostics to {filename}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + spec_augment: Optional[nn.Module] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer2/zipformer.py b/egs/librispeech/ASR/zapformer2/zipformer.py new file mode 100644 index 0000000000..f5e1afe779 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/zipformer.py @@ -0,0 +1,2066 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union +from relative_position_attention_module_optimized import RelativePositionAttentionFunction +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + OrthogonalLinear, + SimpleOrthogonalLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + ScaleLimiter, + ActivationDropoutAndLinear, + ExpNorm, + ChunkCausalDepthwiseConv1d, + CosineSimilarityLoss, + MinProductLoss, + MaxProductLoss, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, + with_loss, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + + dropout (float): dropout rate + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + def __init__( + self, + input_dim: int, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + query_head_dim: Union[int, Tuple[int]] = 24, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_multiple: Union[int, Tuple[int]] = 4, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + dropout: FloatLike = None, # see code below for default + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_multiple = _to_tuple(feedforward_multiple) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + # each one will be Zipformer2Encoder or OrthogonalDownsample or OrthogonalUpsample + encoders = [] + + num_encoders = len(downsampling_factor) + + # caution: some changes we made for this break the streaming, later we'll try to fix this. + encoders_downsampling_factors = [ ] + + # make it so large the limit is never reached. + max_proj_dim = max(downsampling_factor) * max(encoder_dim) + + + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_multiple=feedforward_multiple[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + num_conv_modules=(2 if downsampling_factor[i] == 1 else 1), + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + head_dim=query_head_dim[i], + dim=downsampling_factor[i]*input_dim, + out_proj=False, # (downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + sd_prob: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + sd_prob: + Stochastic-depth prob: with this probability we replace the final output + with the output of a randomly chosen stack (including the 'zero stack' which + means the original input x). Each stack except the 'zero stack' has a + separate output projection for stochastic depth, that only sees the + "non-bypass part", i.e. its encoder stack without the residual. + Returns: + Return (embeddings_lengths), where: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + chunk_size, left_context_chunks = self.get_chunk_info() + orig_seq_len = x.shape[0] + + pad = (-orig_seq_len) % max(self.downsampling_factor) + # pad sequence length to be multiple of max(self.downsampling_factor) + x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), + dim=0) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) + + num_stacks = len(self.downsampling_factor) + + x_sd = x + + def randomly_choose_seqs(x, this_x, prob: float): + batch_size = x.shape[1] + do_replace = (torch.rand(1, batch_size, 1, device=x.device) < prob).expand_as(x) + return torch.where(do_replace, this_x, x) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = downsample_by(x, ds) + T = x.shape[0] + x, this_x_sd = module( + x, + chunk_size=chunk_size, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=(None + if attn_mask is None + else attn_mask[::ds, ::ds] + ), + aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) + ) + x = upsample_by(x, ds) + if sd_prob: + x_sd = randomly_choose_seqs(x_sd, upsample_by(this_x_sd, ds), 1. / (2. + i)) + + + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + od = self.output_downsampling_factor + x = downsample_by(x, od) + x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + if sd_prob: + x_sd = downsample_by(x_sd, od) + x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + x = randomly_choose_seqs(x, x_sd, sd_prob) + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + new_states += new_layer_states + + x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. + + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def get_max_similarity(rank: int, power: float): + """ + This returns a value for the "max_similarity" argument of CosineSimilarityLoss. + the max_similarity is an upper limit we impose on the mean value of (x_i . x_j) + if i != j are two different sequence-position indexes and x_i and x_j are + activation vectors normalized to have unit length. + + rank: the dimension of the space, usually this is the num_channels, but if + we have just up-projected from a bottleneck, it would be the bottleneck + dimension. + power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean + we enforce the vector dimensions to be completely independent like Gaussian noise + (don't do this); if we set power=0.0 it would be equivalent to not having + the CosineSimilarityLoss at all. + + The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal + variable. If x consists of independent Gaussian noise of dimension D, with + variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" + would be close to a no-op for large D), then (x_i . x_j) would be distributed as + a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) + would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value + between this and 1, as a kind of heuristic limit on this max_similarity. + """ + return (0.7978845608 / (rank ** 0.5)) ** power + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +def pad_mask(mask: Optional[Tensor], seq_len: int): + # mask: (batch_size, old_seq_len) + # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). + if mask is None: + return None + (batch_size, old_seq_len) = mask.shape + pad = seq_len - old_seq_len + if pad == 0: + return mask + else: + return torch.cat((mask, torch.ones(batch_size, pad, device=mask.device, dtype=torch.bool)), + dim=1) + + +def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len // downsampling_factor, batch_size, num_channels * downsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len // downsampling_factor, downsampling_factor, batch_size, num_channels) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len // downsampling_factor, batch_size, downsampling_factor * num_channels) + return x + +def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len, batch_size, upsampling_factor, num_channels // upsampling_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + return x + + +def get_dct_matrix(N): + """ + Generates an orthonormal DCT-II matrix for a given size N. + Args: + N (int): The size of the square matrix. + Returns: + torch.Tensor: The N x N orthonormal DCT-II matrix. + """ + # Create the base matrix with dimensions (N, N) + mat = torch.zeros(N, N) + # Create a tensor for the indices k (rows) and n (columns) + k = torch.arange(N).unsqueeze(1) + n = torch.arange(N).unsqueeze(0) + # Fill the matrix using the DCT-II formula + mat = math.sqrt(2 / N) * torch.cos(math.pi / (2 * N) * (2 * n + 1) * k) + # Adjust the first row (k=0) with a special normalization factor + mat[0] *= (2 ** -0.5) + return mat + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_multiple: determines the hidden dimension of the feedforward module + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + value_head_dim: int, + feedforward_multiple: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + num_conv_modules: int = 2, + causal: bool = False, + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + self.name = None # will be set from training loop + + self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + num_heads=2 * num_heads, + query_head_dim=query_head_dim, + dropout=0.0, + ) + + self.self_attn1, self.self_attn2, self.self_attn3 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(3) ] + + feedforward_dim = embed_dim * feedforward_multiple + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) + + if num_conv_modules >= 2: + self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + if num_conv_modules >= 1: + self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + + self.scale_limiter = ScaleLimiter(max_var=2.0) + + self.norm = ExpNorm(embed_dim) + + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, head_dim) or (batch_size, 2*seq_len-1, head_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + aux_loss_scale=0.1 * aux_loss_scale, + ) + num_heads = attn_weights.shape[0] // 2 # num heads per self_attn module + attn_weights1 = attn_weights[:num_heads] + attn_weights2 = attn_weights[num_heads//2:-num_heads//2] + attn_weights3 = attn_weights[num_heads:] + + src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + if hasattr(self, 'conv_module1'): + src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + src = src + self.self_attn3(src, attn_weights3, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + if hasattr(self, 'conv_module2'): + src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) + offset = (src - src_orig) * residual_scale + src = src_orig + offset + + src = with_loss(src, + self.cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), + None) + + src = self.scale_limiter(src) + + src = self.norm(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.residual(src_orig, src) + + src = self.norm(src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + dim: the dimension of the input and output (layer dim may be less than this). + pos_dim: the dimension for the relative positional encoding +dropout: + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + + + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dim: int, + head_dim: int, + out_proj: bool, + ) -> None: + super().__init__() + + # self.downsample will also reverse the downsampling operation for us afterward. + self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) + self.proj.lr_scale = 0.75 + + self.encoder_pos = CompactRelPositionalEncoding( + head_dim, dropout_rate=0.0, length_factor=1.0 + ) + self.name = None + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.residual_scales = nn.Parameter( + torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), + (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], + dim=0)) + + self.copy_bypass = Identity() + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) + + # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear + # module. + if out_proj: + self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) + self.out_proj.lr_scale = 0.75 + + # stochastic-depth proj. + self.sd_proj = nn.Linear(encoder_layer.embed_dim, dim) + + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), + but embed_dim is allowed to exceed the modules' embed_dim; we will bypass + any extra dimensions. + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + (out, out_sd), both of the same shape as src, + where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. + """ + pos_emb = self.encoder_pos(src) + + src_orig_fulldim = src + + src = self.proj(src) # project to layer dim. + + num_layers = len(self.layers) + src_orig = src + + residual_scale = limit_param_value(self.residual_scales[0], + min=-1.0, max=-0.5) + src_with_bypass = residual_scale * src + + for i, mod in enumerate(self.layers): + src = mod( + src, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale/num_layers, + ) + residual_scale = limit_param_value(self.residual_scales[i + 1], + min=0.0 if i + 1 < num_layers else 0.1, + max=1.0) + src_with_bypass = src_with_bypass + residual_scale * src + + + offset = src_with_bypass + + src = src_orig_fulldim + self.proj(offset, transpose=True) + # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, + # because of some identities involving orthogonal matrices. + + if aux_loss_scale: + src = with_loss(src, + self.offset_cosine_loss(offset.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask) + + self.cosine_loss(src.permute(1, 0, 2), + aux_loss_scale, src_key_padding_mask), + None) + + src_sd = self.sd_proj(offset) + + if hasattr(self, 'out_proj'): + src = self.out_proj(src) + + return src, src_sd + + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + num_channels = src.shape[-1] + layer_dim = self.layers[0].embed_dim + if num_channels > layer_dim: + src, bypass = src[..., :layer_dim], src[..., layer_dim:] + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + src, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + src, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + if num_channels > layer_dim: + src = torch.cat((src, bypass), dim=-1) + + return src, new_states + + +class ResidualModule(nn.Module): + """ + An nn.Module that implements a learnable residual scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + function_scale_min: FloatLike = 0.1, + ): + super().__init__() + self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.function_scale_min = copy.deepcopy(function_scale_min) + + + def _get_scales(self): + function_scale = self.function_scale + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + function_scale = limit_param_value( + function_scale, min=float(self.function_scale_min), max=1.0, + ) + residual_scale = 1.0 - function_scale + return residual_scale, function_scale + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + residual_scale, function_scale = self._get_scales() + return residual_scale * src_orig + function_scale * src + + +class OrthogonalDownsample(torch.nn.Module): + """ + Downsamples on sequence axis by appending sequence-positions together, + and then optionally projects by an orthogonal matrix + + + +. Projection is initialized + in a special way and enforced to be orthogonal. + + Args: + channels: the number of input channels; the num output channels will be twice this + proj_dim: the number of channels, after combining 2 frames by interpolating their channels + as [ a b a b, .. ] that will actually be projected; the rest are just copied. + proj_dim=2 * channels would mean all channels are projected in a learned way + causal: True for causal systems, only affects error messages as requires even + input num frames. + """ + def __init__( + self, channels: int, proj_dim: int, causal: bool = False, + ): + super().__init__() + assert proj_dim <= channels * 2 + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + self.causal = causal + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + + if seq_len % 2 == 1: + if torch.jit.is_tracing(): + assert ( + not self.causal + ), f"pad should be zero for exporting streaming models. Given {pad}" + src = torch.cat((src, src[-1:]), dim=0) + seq_len += 1 + + # the following will place each 2 frames of a particular channel right after + # each other as if they were two different channels. + src = torch.stack((src[0::2], src[1::2]), dim=-1) + src = src.reshape(seq_len // 2, batch_size, in_channels * 2) + proj_channels = self.proj.weight.shape[0] + if proj_channels < in_channels * 2: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + return src + +class OrthogonalUpsample(torch.nn.Module): + """ + A very simple form of upsampling with an orthogonal matrix. + + proj_dim: the number of channels that will actually be projected; the rest are just copied. + proj_dim=channels would mean all channels are projected in a learned way + + """ + def __init__(self, channels: int, proj_dim: int): + super().__init__() + assert proj_dim <= channels + # gradually make smaller and then turn off the non-orthognality penalty. + self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, + penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) + # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. + # it will be interpreted by get_parameter_groups_with_lrs() + self.proj.lr_scale = 0.75 + + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*2), batch_size, num_channels // 2) + """ + proj_channels = self.proj.weight.shape[0] + (seq_len, batch_size, in_channels) = src.shape + + if proj_channels < in_channels: + src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), + dim=-1) + else: + src = self.proj(src) + + src = torch.stack((src[..., 0::2], src[..., 1::2]), + dim=1) # (seq_len, 2, batch_size, in_channels // 2) + src = src.reshape(seq_len * 2, batch_size, in_channels // 2) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the Fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0, embed_dim + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0, length_factor + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + dropout: dropout probability for attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, + bias=True, initial_scale=0.125 * query_head_dim**-0.25 + ) + + + self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.5)) + + + # the following are for diagnostics only, see --print-diagnostics option + self.copy_query = Identity() + self.copy_key = Identity() + + self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.6), (20000.0, 6.0), default=5.0)) + + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, head_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.copy_key(k) + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + if aux_loss_scale: + k = with_loss(k, + self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), + aux_loss_scale / num_heads, + key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None), + None) + + + # time1 refers to target, time2 refers to source. + q = q.permute(1, 2, 0, 3) # (batch, head, time1, query_head_dim) + k = k.permute(1, 2, 0, 3) # (batch, head, time2, query_head_dim) + + if self.training: + k = with_loss(k, + self.qk_max_product(q.reshape(batch_size * num_heads, seq_len, query_head_dim), + k.reshape(batch_size * num_heads, seq_len, query_head_dim), + aux_loss_scale / num_heads), + None) + + + attn_scores = RelativePositionAttentionFunction.apply(q.contiguous(), k.contiguous(), pos_emb.repeat(num_heads, 1, 1)) + + + assert attn_scores.shape == (batch_size, num_heads, seq_len, seq_len) + attn_scores = attn_scores.permute(1, 0, 2, 3) + # (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + + # HERE.. not finished streaming code. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, + bias=True, out_groups=num_heads) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + f = max(1.0, embed_dim / (num_heads * value_head_dim)) + + self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) + + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + aux_loss_scale: float = 0.0, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + src_key_padding_mask: optional Tensor of shape (batch_size, src_seq_len); only + used for the cosine similarity loss, during training. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # x: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + if aux_loss_scale: + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), + aux_loss_scale, + mask=src_key_padding_mask), None) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + # try to get in the useful range of the activation function, i.e. not too small. + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed + # to the TransformedAdam optimizer. + self.in_proj.weight_min_rms = 0.02 + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwashL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.5, + ) + + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) + + + def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + x = self.in_proj(x) + x = self.out_proj(x) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), None) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwashR", + dropout_p=0.0, + initial_scale=0.05, + ) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.6)) + + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), + None) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + seq_len = 20 + # Just make sure the forward pass runs. + + input_dim = 50 + + c = Zipformer2( + input_dim=input_dim, + encoder_dim=(64, 96), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + + batch_size = 6 + seq_len = 21 + # Just make sure the forward pass runs. + f, lengths = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + sd_prob=0.1, + ) + f.sum().backward() + c.eval() + x_ = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + sd_prob=0.1, + ) + x_ # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) From 45e5eb25acfc36b51bed70a5eb348ae33861f8ed Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Sep 2025 03:55:00 +0800 Subject: [PATCH 0558/1191] Reduce scaling_lr_scale fro 0.1 to 0.05. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ac4dd403e2..8869754cca 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -405,7 +405,7 @@ def __init__( direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scalar_lr_scale=0.1, - scaling_lr_scale=0.1, + scaling_lr_scale=0.05, eps=1.0e-08, weight_min_scale=0.005, weight_max_scale=1.0, From 0e3fd63c1835cd3dd023d5ad6f86a0a1ce84b42e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Sep 2025 04:19:46 +0800 Subject: [PATCH 0559/1191] Introduce scale_decay=0.01 to decay log scales to a default at log(0.05). --- egs/librispeech/ASR/zipformer/optim.py | 33 ++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8869754cca..050a90254b 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -15,11 +15,12 @@ # limitations under the License. import contextlib +import math import logging import random from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from lhotse.utils import fix_random_seed from torch import Tensor @@ -239,14 +240,20 @@ def reverse_transform_param(group, p, orig_shape): # of scaling. max_scale = 0.7978845608028654 * (group["weight_max_scale"] if is_weight else group["bias_max_scale"]) min_scale = 0.7978845608028654 * (group["weight_min_scale"] if is_weight else group["bias_min_scale"]) - scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]).exp().clamp(min=min_scale, max=max_scale) + log_scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]) + + # the factor of 1.2533141373155001 is a factor we include in lr, to correct for a change to rms to mean-abs + # value. + scaling_lr = 1.2533141373155001 * group["scaling_lr_scale"] * group["lr"] + + # Apply weight-decay of log_scale, similar to weight decay of AdamW, except it regresses the + # log-scale to a default value instead of regressing the scale towards zero. + log_scale_default = group["log_scale_default"] + log_scale = ((log_scale - log_scale_default) * (1. - group["scale_decay"] * scaling_lr)) + log_scale_default + scale = log_scale.exp().clamp(min=min_scale, max=max_scale) q = p_padded[:, :-1] * scale # the :-1 is to remove the padding element. q = q.reshape(*orig_shape) - # Now include the scaling factors. these were originally all zero as returned from - # forward_transform_param. - offset = numel + 2 # + 1 for the padding element and the log-scale. - return q @@ -385,6 +392,9 @@ class TransformedAdam(BatchedOptimizer): scale of each non-scalar parameter tensor. If each parameter were decomposed as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the scaling factor on the learning rate of p_scale. + scale_decay: A constant similar to the weight_decay of AdamW, that applies on the scaling + factors, decaying them in log-space to scale_default. + scale_default: A constant that dictates the RMS value to which weight magnitudes decay. scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. eps: A general-purpose epsilon to prevent division by zero weight_min_scale, weight_max_scale: Minimum and maximum respectively of weight tensor @@ -404,6 +414,8 @@ def __init__( beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, + scale_decay=0.01, + scale_default=0.05, scalar_lr_scale=0.1, scaling_lr_scale=0.05, eps=1.0e-08, @@ -422,6 +434,8 @@ def __init__( beta1=beta1, direct=direct, beta2=beta2, + scale_decay=scale_decay, + log_scale_default=math.log(scale_default), scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, @@ -887,6 +901,8 @@ def __init__( beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, + scale_decay=0.01, + scale_default=0.05, scalar_lr_scale=0.1, scaling_lr_scale=0.1, eps=1.0e-08, @@ -903,6 +919,8 @@ def __init__( beta1=beta1, direct=direct, beta2=beta2, + scale_decay=scale_decay, + log_scale_default=math.log(scale_default), scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, @@ -1553,7 +1571,8 @@ def _test_transformed_adam(hidden_dim: int): def _test_transform_params(): # caution: this has occasional errors. group = { "bias_min_scale": 0.001, "weight_min_scale": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5, - "weight_max_scale": 20.0, "bias_max_scale": 20.0 } + "log_scale_default": 0.05, "scale_decay": 0.01, + "weight_max_scale": 20.0, "bias_max_scale": 20.0, "lr": 0.0} # lr set to 0.0 so weight-scale decay does not happen. for scale in [ 0.0, 1.0e-05, 0.001, 0.01, 1.0, 10.0 ]: for shape in [ (1, 1), (2, 1), (2, 2), (2, 3, 4), (3, 10, 20), (4,) ]: p = scale * torch.randn(*shape) From ce6a09231978875b82a34c499a3e936297cb8f5a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Sep 2025 23:06:26 +0800 Subject: [PATCH 0560/1191] Increase layers from 3,5,6,6,6,5 to 4,6,7,7,7,6, adding one to each. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 18b842ca7c..17a7079379 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="3,5,6,6,6,5", + default="4,6,7,7,7,6", help="Number of zipformer encoder layers per stack, comma separated.", ) From 328b7a70f046625574e8f907535e392cff569868 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 24 Sep 2025 00:27:23 +0800 Subject: [PATCH 0561/1191] Revert scaling_lr_scale from 0.05 to 0.1. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 050a90254b..bc2b7aa1ee 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -417,7 +417,7 @@ def __init__( scale_decay=0.01, scale_default=0.05, scalar_lr_scale=0.1, - scaling_lr_scale=0.05, + scaling_lr_scale=0.1, eps=1.0e-08, weight_min_scale=0.005, weight_max_scale=1.0, From 2362e25e29d97e9fd1fbdfdd5346f51e96338c3e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 24 Sep 2025 03:27:57 +0800 Subject: [PATCH 0562/1191] Remove the normalization step in reverse_transform_params. --- egs/librispeech/ASR/zipformer/optim.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index bc2b7aa1ee..b6f9b70af0 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -207,7 +207,7 @@ def forward_transform_param(group, p): min_scale = 0.7978845608028654 * (group["weight_min_scale"] if is_weight else group["bias_min_scale"]) p_flat = p.reshape(batch_size, numel) abs_sum = p_flat.abs().sum(dim=1, keepdim=True) - min_abs_sum = min_scale * numel # if sumsq is less than this we pad with an extra element. + min_abs_sum = min_scale * numel # if abs_sum is less than this we pad with an extra element. abs_sum_clamped = abs_sum.clamp(min=min_abs_sum) pad = (abs_sum_clamped - abs_sum) scale = (abs_sum_clamped / numel) # must be nonzero thanks to min_abs_sum @@ -226,13 +226,6 @@ def reverse_transform_param(group, p, orig_shape): # numel is num elements of each parameter tensor in the batch. numel = p.shape[1] - 2 p_padded = p[:, :numel+1] # orig tensor plus one padding element - # the next line normalizes the scale to 1, because the update step will have - # changed it slightly versus the normalized state that forward_transform_param - # put it into. The correction factor (numel + 1) / numel is to account - # for the fact that it's actuallty the sum() / numel that should equal 1, - # but we prefer to use mean to avoid out-of-range numerical errors for large tensors - # if this code gets used in fp16 in the future. - p_padded = p_padded / (p_padded.abs().mean(dim=1, keepdim=True) * ((numel + 1) / numel)) is_weight = (len(orig_shape) > 2) # 0.7978845608028654 is sqrt(2/pi) which is a correction factor for the ratio of (abs value / rms value) From 75c60f57c0d9f712958669c31caf3f320b6d6408 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 28 Sep 2025 03:33:19 +0800 Subject: [PATCH 0563/1191] Remove schedules and dropout from Zipformer2. --- egs/librispeech/ASR/zapformer/train.py | 1 - egs/librispeech/ASR/zipformer/zipformer.py | 300 +-------------------- 2 files changed, 11 insertions(+), 290 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 17a7079379..8f63ec517a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -743,7 +743,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=lookup(params, "num_heads"), feedforward_multiple=lookup(params, "feedforward_multiple"), cnn_module_kernel=lookup(params, "cnn_module_kernel"), - dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero causal=params.causal, chunk_size=lookup(params, "chunk_size"), left_context_frames=lookup(params, "left_context_frames"), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3e6f2c0db7..7c5cab0168 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -37,10 +37,7 @@ CosineSimilarityLoss, MinProductLoss, MaxProductLoss, - Dropout2, FloatLike, - ScheduledFloat, - Whiten, convert_num_channels, limit_param_value, penalize_abs_values_gt, @@ -80,7 +77,6 @@ class Zipformer2(EncoderInterface): pos_dim (int): the dimension of each positional-encoding vector prior to projection, e.g. 128. - dropout (float): dropout rate causal (bool): if True, support chunkwise causal convolution. This should not hurt WER as no modeling power is lost, but the convolution modules will be slightly slower and use more memory. Enables use of the chunk_size and @@ -107,16 +103,12 @@ def __init__( feedforward_multiple: Union[int, Tuple[int]] = 4, cnn_module_kernel: Union[int, Tuple[int]] = 31, pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default causal: bool = False, chunk_size: Tuple[int] = [-1], left_context_frames: Tuple[int] = [-1], ) -> None: super(Zipformer2, self).__init__() - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - def _to_tuple(x): """Converts a single int or a 1-tuple of an int to a tuple with the same length as downsampling_factor""" @@ -166,7 +158,6 @@ def _to_tuple(x): pos_head_dim=pos_head_dim[i], value_head_dim=value_head_dim[i], feedforward_multiple=feedforward_multiple[i], - dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], num_conv_modules=1, causal=causal, @@ -493,14 +484,6 @@ def get_max_similarity(rank: int, power: float): """ return (0.7978845608 / (rank ** 0.5)) ** power -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - def pad_mask(mask: Optional[Tensor], seq_len: int): # mask: (batch_size, old_seq_len) # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). @@ -560,7 +543,7 @@ class Zipformer2EncoderLayer(nn.Module): embed_dim: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). feedforward_multiple: determines the hidden dimension of the feedforward module - dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). Examples:: @@ -578,7 +561,6 @@ def __init__( pos_head_dim: int, value_head_dim: int, feedforward_multiple: int, - dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, num_conv_modules: int = 2, causal: bool = False, @@ -597,17 +579,16 @@ def __init__( num_heads=2 * num_heads, query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, - dropout=0.0, ) self.self_attn1, self.self_attn2, self.self_attn3 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(3) ] feedforward_dim = embed_dim * feedforward_multiple - self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4) - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim) - self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4) if num_conv_modules >= 2: self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) @@ -821,7 +802,6 @@ class Zipformer2Encoder(nn.Module): num_layers: the number of sub-encoder-layers in the encoder (required). dim: the dimension of the input and output (layer dim may be less than this). pos_dim: the dimension for the relative positional encoding -dropout: Examples:: >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) @@ -846,7 +826,7 @@ def __init__( self.proj.lr_scale = 0.75 self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.0, length_factor=1.0 + pos_dim, length_factor=1.0 ) self.name = None self.layers = nn.ModuleList( @@ -1058,102 +1038,6 @@ def forward(self, src_orig: Tensor, src: Tensor): return residual_scale * src_orig + function_scale * src -class OrthogonalDownsample(torch.nn.Module): - """ - Downsamples on sequence axis by appending sequence-positions together, - and then optionally projects by an orthogonal matrix - - - -. Projection is initialized - in a special way and enforced to be orthogonal. - - Args: - channels: the number of input channels; the num output channels will be twice this - proj_dim: the number of channels, after combining 2 frames by interpolating their channels - as [ a b a b, .. ] that will actually be projected; the rest are just copied. - proj_dim=2 * channels would mean all channels are projected in a learned way - causal: True for causal systems, only affects error messages as requires even - input num frames. - """ - def __init__( - self, channels: int, proj_dim: int, causal: bool = False, - ): - super().__init__() - assert proj_dim <= channels * 2 - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.75 - self.causal = causal - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - - if seq_len % 2 == 1: - if torch.jit.is_tracing(): - assert ( - not self.causal - ), f"pad should be zero for exporting streaming models. Given {pad}" - src = torch.cat((src, src[-1:]), dim=0) - seq_len += 1 - - # the following will place each 2 frames of a particular channel right after - # each other as if they were two different channels. - src = torch.stack((src[0::2], src[1::2]), dim=-1) - src = src.reshape(seq_len // 2, batch_size, in_channels * 2) - proj_channels = self.proj.weight.shape[0] - if proj_channels < in_channels * 2: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) - return src - -class OrthogonalUpsample(torch.nn.Module): - """ - A very simple form of upsampling with an orthogonal matrix. - - proj_dim: the number of channels that will actually be projected; the rest are just copied. - proj_dim=channels would mean all channels are projected in a learned way - - """ - def __init__(self, channels: int, proj_dim: int): - super().__init__() - assert proj_dim <= channels - # gradually make smaller and then turn off the non-orthognality penalty. - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, - penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.75 - - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*2), batch_size, num_channels // 2) - """ - proj_channels = self.proj.weight.shape[0] - (seq_len, batch_size, in_channels) = src.shape - - if proj_channels < in_channels: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) - - src = torch.stack((src[..., 0::2], src[..., 1::2]), - dim=1) # (seq_len, 2, batch_size, in_channels // 2) - src = src.reshape(seq_len * 2, batch_size, in_channels // 2) - return src - class CompactRelPositionalEncoding(torch.nn.Module): """ @@ -1175,7 +1059,6 @@ class CompactRelPositionalEncoding(torch.nn.Module): Args: embed_dim: Embedding dimension. - dropout_rate: Dropout rate. max_len: Maximum input length: just a heuristic for initialization. length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives less weight to small differences of offset near the origin. @@ -1184,7 +1067,6 @@ class CompactRelPositionalEncoding(torch.nn.Module): def __init__( self, embed_dim: int, - dropout_rate: FloatLike, max_len: int = 1000, length_factor: float = 1.0, ) -> None: @@ -1192,7 +1074,6 @@ def __init__( super(CompactRelPositionalEncoding, self).__init__() self.embed_dim = embed_dim assert embed_dim % 2 == 0, embed_dim - self.dropout = Dropout2(dropout_rate) self.pe = None assert length_factor >= 1.0, length_factor self.length_factor = length_factor @@ -1270,7 +1151,7 @@ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: :, ] pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) + return pos_emb class RelPositionMultiheadAttentionWeights(nn.Module): @@ -1288,7 +1169,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module): num_heads: number of heads to compute weights for, e.g. 8 query_head_dim: dimension of the query (and key), per head. e.g. 24. pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on any given call to forward(), in training time. """ @@ -1300,14 +1180,12 @@ def __init__( num_heads: int, query_head_dim: int, pos_head_dim: int, - dropout: float = 0.0, ) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim - self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. key_head_dim = query_head_dim @@ -1337,8 +1215,8 @@ def __init__( self.copy_query = Identity() self.copy_key = Identity() - self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.6), (20000.0, 6.0), default=5.0)) - self.pos_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.4), (20000.0, 4.0), default=5.0)) + self.qk_max_product = MaxProductLoss(max_product=6.0) + self.pos_max_product = MaxProductLoss(max_product=4.0) def forward( @@ -1491,10 +1369,6 @@ def forward( elif random.random() < 0.001: self._print_attn_entropy(attn_weights) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - return attn_weights def streaming_forward( @@ -1658,7 +1532,7 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) + self.cosine_loss = CosineSimilarityLoss(max_similarity=0.75) def forward( @@ -1768,7 +1642,7 @@ def streaming_forward( class FeedforwardModule(nn.Module): """Feedforward module in Zipformer2 model.""" - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + def __init__(self, embed_dim: int, feedforward_dim: int): super(FeedforwardModule, self).__init__() # try to get in the useful range of the activation function, i.e. not too small. self.in_proj = ScaledLinear(embed_dim, feedforward_dim) @@ -1781,8 +1655,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): feedforward_dim, embed_dim, activation="SwashL", - dropout_p=dropout, - dropout_shared_dim=0, + dropout_p=0.0, bias=True, initial_scale=0.5, ) @@ -1797,157 +1670,6 @@ def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: return x -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) - - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. - Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == ( - num_heads, - batch_size, - seq_len, - left_context_len + seq_len, - ) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, ( - cached_x.shape[2], - left_context_len, - ) - x_pad = torch.cat([cached_x, x], dim=2) - # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - x = x * y - - x = self.out_proj(x) - return x, cached_x - class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer2 model. From 0b78f97151a9f104ac27f39354e8375012017bcc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 27 Sep 2025 07:36:22 +0800 Subject: [PATCH 0564/1191] Un-tie the self attention weights. --- egs/librispeech/ASR/zipformer/zipformer.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7c5cab0168..19031bf00a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -576,7 +576,7 @@ def __init__( self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, - num_heads=2 * num_heads, + num_heads=3 * num_heads, query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, ) @@ -638,10 +638,7 @@ def forward( key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale, ) - num_heads = attn_weights.shape[0] // 2 # num heads per self_attn module - attn_weights1 = attn_weights[:num_heads] - attn_weights2 = attn_weights[num_heads//2:-num_heads//2] - attn_weights3 = attn_weights[num_heads:] + attn_weights1, attn_weights2, attn_weights3 = attn_weights.chunk(3, dim=0) src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) From 90f43f4e6a15039f236724786ba1775fffa5ac2e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 29 Sep 2025 01:40:15 +0800 Subject: [PATCH 0565/1191] Increase scale_default from 0.05 to 0.2 --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b6f9b70af0..b84ab2d62b 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -408,7 +408,7 @@ def __init__( direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, - scale_default=0.05, + scale_default=0.2, scalar_lr_scale=0.1, scaling_lr_scale=0.1, eps=1.0e-08, From 2f04b954468215a3c0c18e02761ea6db485103f4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 30 Sep 2025 04:39:38 +0800 Subject: [PATCH 0566/1191] In optim.py, remove the factors related to sqrt(2/pi)~0.8 --- egs/librispeech/ASR/zipformer/optim.py | 33 +++++--------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b84ab2d62b..841d5edd30 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -168,19 +168,6 @@ def momentum_step(group, p, state, grad): stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) state["delta"] = stored_delta - # 1.2533141373155001 is sqrt(pi/2) which is a correction factor for the - # ratio of (rms value / abs value) of a normal distribution, made when we - # switched from using rms values to abs value for purposes of scaling. This - # does not apply to scalar parameters (p.numel() == p.shape[0], dimension 0 - # is the same-sized-parameter-tensor batch dimension), which are not subject - # to scaling by inverse-absolute-values. The update is going to get - # multiplied by the mean-absolute-value, i.e. the scaling factor, which is - # equal to sqrt(2/pi) times the rms value for normally distributed data, and - # we want the step size to be the same as before for normally distributed - # data, which means we need to multiply by sqrt(pi/2). - lr = (1.2533141373155001 * lr if p.numel() > p.shape[0] else lr) - - stored_delta.mul_(beta1).add_(delta) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) @@ -201,10 +188,7 @@ def forward_transform_param(group, p): return p.reshape(batch_size, 1) / group["scalar_lr_scale"] is_weight = (p.ndim > 2) - # 0.7978845608028654 is sqrt(2/pi) which is a correction factor for the ratio of (abs value / rms value) - # of a normal distribution, made when we switched from using rms values to abs value for purposes - # of scaling. - min_scale = 0.7978845608028654 * (group["weight_min_scale"] if is_weight else group["bias_min_scale"]) + min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] p_flat = p.reshape(batch_size, numel) abs_sum = p_flat.abs().sum(dim=1, keepdim=True) min_abs_sum = min_scale * numel # if abs_sum is less than this we pad with an extra element. @@ -228,16 +212,11 @@ def reverse_transform_param(group, p, orig_shape): p_padded = p[:, :numel+1] # orig tensor plus one padding element is_weight = (len(orig_shape) > 2) - # 0.7978845608028654 is sqrt(2/pi) which is a correction factor for the ratio of (abs value / rms value) - # of a normal distribution, made when we switched from using rms values to abs value for purposes - # of scaling. - max_scale = 0.7978845608028654 * (group["weight_max_scale"] if is_weight else group["bias_max_scale"]) - min_scale = 0.7978845608028654 * (group["weight_min_scale"] if is_weight else group["bias_min_scale"]) + max_scale = group["weight_max_scale"] if is_weight else group["bias_max_scale"] + min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] log_scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]) - # the factor of 1.2533141373155001 is a factor we include in lr, to correct for a change to rms to mean-abs - # value. - scaling_lr = 1.2533141373155001 * group["scaling_lr_scale"] * group["lr"] + scaling_lr = group["scaling_lr_scale"] * group["lr"] # Apply weight-decay of log_scale, similar to weight decay of AdamW, except it regresses the # log-scale to a default value instead of regressing the scale towards zero. @@ -1501,9 +1480,9 @@ def _test_transformed_adam(hidden_dim: int): ] if test == 0: - optim = SimpleTransformedAdam(m.parameters(), lr=0.06, eps=1.0e-20) + optim = SimpleTransformedAdam(m.parameters(), lr=0.075, eps=1.0e-20) elif test == 1: - optim = TransformedAdam(m.named_parameters(), lr=0.06, clipping_scale=2.0, eps=1.0e-20) + optim = TransformedAdam(m.named_parameters(), lr=0.075, clipping_scale=2.0, eps=1.0e-20) elif test == 2: optim = Eve(m.parameters(), lr=0.003) else: From 64f733cf27e42e76b2f9b36b1b2a86fab7199b6a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 30 Sep 2025 04:42:24 +0800 Subject: [PATCH 0567/1191] Remove unused argument to optimizer --- egs/librispeech/ASR/zipformer/optim.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 841d5edd30..72405baa0a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -395,7 +395,6 @@ def __init__( weight_max_scale=1.0, bias_min_scale=1.0e-05, bias_max_scale=5.0, - size_update_period=4, clipping_update_period=100, debug_interval=0, ): From ea845cf4c2bf78f47f9664f966084c09dfd03807 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 30 Sep 2025 05:03:24 +0800 Subject: [PATCH 0568/1191] Change the formula of ExpNorm so that it is, in effect, TanhNorm. --- egs/librispeech/ASR/zipformer/scaling.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index df5791b035..6956eb9e65 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, floor: Optional[Tensor]): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (1. - (-x_norm).exp()) + num = torch.nn.functional.tanh(x_norm) if floor is not None: num = torch.maximum(num, floor) scales = num / x_norm @@ -344,11 +344,16 @@ def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, floor: Optional[Tensor class ExpNormFunction(torch.autograd.Function): # This computes: - # scales = (torch.mean(x ** 2 + eps, keepdim=True)) ** -0.5 * log_scale.exp() + # Equivalent to: + # x_norm = x.norm(dim=-1, keepdim=True) + # scales = x_norm.tanh() / x_norm * scale + # (scale is a user-provided scaling factor that is learnable).. # return x * scales - # (after unsqueezing the bias), but it does it in a memory-efficient way so that - # it can just store the returned value (chances are, this will also be needed for - # some other reason, related to the next operation, so we can save memory). + # + # .. if rand_floor != 0.0, it does a randomized method that sometimes modifies + # (increases) the scale if x_norm is less than rand_floor; this is intended + # to penalize too-small x values, which can otherwise occur after a lot of training + # and could destabilize the network's training. @staticmethod def forward( ctx, @@ -433,7 +438,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.25, + rand_floor: FloatLike = 0.15, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From 0704866727f4f612b1754c4a02acfbded3f4c1e0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 30 Sep 2025 05:13:13 +0800 Subject: [PATCH 0569/1191] Remove rand_floor from ExpNorm; add ScaleLimiter with min_rms,max_rms; change scaling of min_deviation,max_deviation in ScaleLimiter, as in 1257->1258; reduce min_rms in ScaleLimiter from 0.2 to 0.15. --- egs/librispeech/ASR/zipformer/scaling.py | 65 +++++++++++----------- egs/librispeech/ASR/zipformer/zipformer.py | 15 +++-- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6956eb9e65..00b3eb884c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1468,55 +1468,56 @@ def streaming_forward( class ScaleLimiterFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, max_var: float): + def forward(ctx, x: Tensor, min_rms: float, max_rms: float, aux_loss_scale: float, name: str): ctx.save_for_backward(x) - ctx.max_var = max_var + ctx.min_rms = min_rms + ctx.max_rms = max_rms + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name return x @staticmethod - def backward(ctx, y_grad: Tensor): + def backward(ctx, x_grad: Tensor): x, = ctx.saved_tensors - # you could think of loss_scale as like a mask, it's nonzero if - # (x**2).mean() > 1.0, but it starts of small if we are close to 1.0 - # so we don't suddenly add large gradients that could be destabilizing. - eps = 0.01 - loss_scale = eps * ((x.to(torch.float) ** 2).mean() - ctx.max_var).relu() - y_grad_abs_mean = y_grad.abs().mean() - # y_grad_abs_mean is a scaling factor for the gradient contribution, since we - # don't know at this point the total scale of the main loss. - - # the grad of (x ** 2).mean() would be 2 * x. we absorb the factor of 2 - # into eps, which is just an arbitrary smallish value. - return y_grad + (loss_scale * y_grad_abs_mean) * x, None + with torch.enable_grad(): + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + x = x.detach() + x.requires_grad = True + rms = (x ** 2).mean(dim=-1).sqrt() + max_deviation = (rms / ctx.max_rms - 1.).relu() + min_deviation = (1. - rms / ctx.min_rms).relu() + + if random.random() < 0.002: + logging.info( + f"ScaleLimiter: name={ctx.name}, min_rms={ctx.min_rms}, max_rms={ctx.max_rms}, " + f"min_deviation={min_deviation.mean()}, max_deviation={max_deviation.mean()}, " + f"loss_scale={ctx.aux_loss_scale}" + ) + (min_deviation + max_deviation).backward(gradient=torch.full_like(min_deviation, ctx.aux_loss_scale)) + return x_grad + x.grad, None, None, None, None class ScaleLimiter(torch.nn.Module): """ - Tries to make the average square value of the features no greater than self.max_var, by - adding a penalty. This is not per dimension, but globally. + Adds a penalty in backprop if the norm of any activation vector is less than min_rms + or more than max_rms. + Assumes channel dim is -1 and the input shape has >1 dimension. - Caution: max_var is actually a maximum variance. """ - def __init__(self, max_var: FloatLike = 1.0, prob: FloatLike = 1.0): + def __init__(self, min_rms: FloatLike, max_rms: FloatLike): super().__init__() self.name = None - self.max_var = max_var - self.prob = prob + self.min_rms = min_rms + self.max_rms = max_rms - def forward(self, x: Tensor) -> Tensor: + + def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return _no_op(x) else: - # this in effect adds a penalty to the loss function if - # (x ** 2).mean() > 1.0, the penalty will tend to reduce the value - # of (x ** 2). - if random.random() < 0.001: - logging.info(f"name={self.name}, max_var={float(self.max_var)}, prob={float(self.prob)}, x_rms={(x**2).mean().sqrt().item()}") - prob = float(self.prob) - if prob > 0 and random.random() < prob: - return ScaleLimiterFunction.apply(x, float(self.max_var)) - else: - return x + return ScaleLimiterFunction.apply(x, float(self.min_rms), float(self.max_rms), + aux_loss_scale, self.name) def penalize_abs_values_gt( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 19031bf00a..64aabc8f08 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -571,6 +571,7 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.self_attn_weights = RelPositionMultiheadAttentionWeights( @@ -595,9 +596,9 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(max_var=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.15, max_rms=2.0) - self.norm = ExpNorm(embed_dim) + self.norm = ExpNorm(embed_dim, rand_floor=0.0) # rely on scale_limiter for floor on norm. def forward( @@ -663,10 +664,16 @@ def forward( src = src_orig + offset src = with_loss(src, - self.cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), + self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), None) - src = self.scale_limiter(src) + # also put cosine_loss on src, mostly because it will be used in scale_limiter and we don't want the + # network to get around the scale limitation by using an offset. + src = with_loss(src, + self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), + None) + + src = self.scale_limiter(src, aux_loss_scale) src = self.norm(src) From 1c1be8edb3a63e4fc0acf606b406e0168098ee53 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Oct 2025 03:31:43 +0800 Subject: [PATCH 0570/1191] Remove dropout from CTC output. --- egs/librispeech/ASR/zapformer/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 278e498032..48117b7a31 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -115,7 +115,6 @@ def __init__( if use_ctc: # Modules for CTC head self.ctc_output = nn.Sequential( - nn.Dropout(p=0.1), ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), nn.LogSoftmax(dim=-1), ) From b4cb9db417a1166d9f92e3bb93c0524ff8aa0021 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Oct 2025 05:38:32 +0800 Subject: [PATCH 0571/1191] Take account of padding mask when Gaussianizing log mels as targets for prediction; remove unnecessary load_state_dict_pre_hook. --- egs/librispeech/ASR/zapformer/model.py | 51 +++++++++++++++++--------- egs/librispeech/ASR/zipformer/optim.py | 13 ------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 48117b7a31..613947faa0 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -553,12 +553,42 @@ def forward( else: attention_decoder_loss = torch.empty(0) - reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, - encoder_out_lens) + reconstruction_loss = self.forward_reconstruction_loss(self.gauss_norm(x_no_specaug, x_lens), + encoder_out, encoder_out_lens) return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss + + def gauss_norm(self, + log_mels: Tensor, + log_mel_lens: Tensor) -> Tensor: + (batch_size, seq_len, num_channels) = log_mels.shape + + randpos = torch.randint(seq_len, (batch_size, seq_len, num_channels), device=log_mels.device) + + rand_pos = rand_pos % log_mel_lens.unsqueeze(-1).unsqueeze(-1) + + + log_mels_rand = torch.gather(log_mels, dim=1, index=rand_pos) + length_mask = make_pad_mask(encoder_out_lens) # True in masked positions + length_mask = length_mask.unsqueeze(-1).expand_as(log_mels) + + log_mels = torch.where(length_mask, log_mels_rand, log_mels) + # OK, now for out-of-bounds positions we have selected randomly chosen within-bounds positions. + + values, indexes = log_mels.sort(dim=1) # sort on seq dim + N = max(2, log_mels.shape[1]) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, log_mels.shape[1], device=x.device, dtype=torch.float) + norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data + norm_rank = norm_rank.reshape(1, -1, 1) + norm_rank = norm_rank.repeat(log_mels.shape[0], 1, x.shape[2]) + log_mels_norm = torch.empty_like(log_mels) + log_mels_norm.scatter_(dim=1, index=indexes, src=norm_rank) + return log_mels_norm + + + def forward_reconstruction_loss(self, log_mels: Tensor, encoder_out: Tensor, @@ -574,21 +604,6 @@ def forward_reconstruction_loss(self, batch_size = log_mels.shape[0] num_mels = log_mels.shape[2] - - def gauss_norm(x): - # normalize by gaussianizing on each dimension - values, indexes = x.sort(dim=1) # sort on seq dim - N = max(2, x.shape[1]) - norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[1], device=x.device, dtype=torch.float) - norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data - norm_rank = norm_rank.reshape(1, -1, 1) - norm_rank = norm_rank.repeat(x.shape[0], 1, x.shape[2]) - x_norm = torch.empty_like(x) - x_norm.scatter_(dim=1, index=indexes, src=norm_rank) - return x_norm - - log_mels = gauss_norm(log_mels) - pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) T_embed = pred_mels.shape[1] pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) @@ -613,7 +628,7 @@ def gauss_norm(x): # reduction='none', beta=1.0) # this way of applying the padding mask is not really ideal in terms of normalization, # it will cause us to under-normalize a bit. - diff = log_mels * pad_mask - pred_mels * pad_mask + diff = (log_mels - pred_mels) * pad_mask loss = (diff ** 2) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 72405baa0a..376d4a5aa2 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -320,16 +320,6 @@ def _write_debug_info(group, state, param_names, summary_writer): summary_writer.add_scalar(debug_str, value, step) - -def _load_state_dict_pre_hook(optim: Optimizer, state_dict: dict): - for optim_group, load_group in zip(optim.param_groups, state_dict['param_groups']): - for key in ['debug_interval']: - try: - optim_group[key] = load_group[key] - logging.info(f"Copied key {key}") - except KeyError: - logging.info(f"Could not copy key {key} from optim state-dict.") - class TransformedAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update @@ -428,9 +418,6 @@ def __init__( self.parameters_names = parameters_names - self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook) - - def _get_names_of_parameters( self, params_or_named_params From d1bc35867528c73b5b91b23f42f7d70c195e7fa8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Oct 2025 05:44:16 +0800 Subject: [PATCH 0572/1191] Remove stochastic depth and predict_loss from the code. --- egs/librispeech/ASR/zapformer/model.py | 9 +++---- egs/librispeech/ASR/zapformer/train.py | 23 +--------------- egs/librispeech/ASR/zipformer/zipformer.py | 31 ++-------------------- 3 files changed, 6 insertions(+), 57 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 613947faa0..4a72add0c4 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -130,7 +130,7 @@ def __init__( def forward_encoder( - self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, sd_prob: float = 0.0, + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute encoder outputs. Args: @@ -167,8 +167,7 @@ def forward_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, - aux_loss_scale=aux_loss_scale, - sd_prob=0.0) + aux_loss_scale=aux_loss_scale) predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) @@ -397,7 +396,6 @@ def forward( time_warp_factor: Optional[int] = 80, num_copies: int = 1, aux_loss_scale: float = 0.0, - sd_prob: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -501,8 +499,7 @@ def forward( # Compute encoder outputs encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, - aux_loss_scale=aux_loss_scale, - sd_prob=sd_prob) + aux_loss_scale=aux_loss_scale) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 8f63ec517a..86a95a5b00 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -543,21 +543,6 @@ def get_parser(): help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", ) - parser.add_argument( - "--predict-loss-scale", - type=float, - default=0.01, - help="Prediction of random k-means after widest zipformer layer" - ) - - parser.add_argument( - "--stochastic-depth-prob", - type=float, - default=0.1, - help="Probability of using a randomly chosen stack output during training, instead of " - "final output." - ) - parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -1003,7 +988,7 @@ def compute_loss( spec_augment = None # disable spec-aug with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -1015,7 +1000,6 @@ def compute_loss( time_warp_factor=80, # for specaug num_copies=num_copies, aux_loss_scale=aux_loss_scale, - sd_prob=(params.stochastic_depth_prob if is_training else 0.0), ) loss = 0.0 @@ -1042,9 +1026,6 @@ def warmup_schedule(scale, initial_factor): loss += reconstruction_loss_scale * reconstruction_loss - if num_copies > 1: - loss += params.predict_loss_scale * predict_loss - if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -1067,8 +1048,6 @@ def warmup_schedule(scale, initial_factor): info["ctc_loss"] = ctc_loss.detach().cpu().item() if num_copies > 1: info["cr_loss"] = cr_loss.detach().cpu().item() - if num_copies > 1: - info["predict_loss"] = predict_loss.detach().cpu().item() info["recon_loss"] = reconstruction_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 64aabc8f08..40771e152c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -211,7 +211,6 @@ def forward( x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, - sd_prob: float = 0.0, ) -> Tuple[Tensor, Tensor]: """ Args: @@ -227,12 +226,6 @@ def forward( If supplied, auxiliary losses such as CosineSimilarityLoss will be applied with this scale on the loss (note, these aux losses are reduced via summation over frames.) - sd_prob: - Stochastic-depth prob: with this probability we replace the final output - with the output of a randomly chosen stack (including the 'zero stack' which - means the original input x). Each stack except the 'zero stack' has a - separate output projection for stochastic depth, that only sees the - "non-bypass part", i.e. its encoder stack without the residual. Returns: Return (embeddings_lengths), where: - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) @@ -257,18 +250,12 @@ def forward( num_stacks = len(self.downsampling_factor) - x_sd = x - - def randomly_choose_seqs(x, this_x, prob: float): - batch_size = x.shape[1] - do_replace = (torch.rand(1, batch_size, 1, device=x.device) < prob).expand_as(x) - return torch.where(do_replace, this_x, x) for i, module in enumerate(self.encoders): ds = self.downsampling_factor[i] x = downsample_by(x, ds) T = x.shape[0] - x, this_x_sd = module( + x = module( x, chunk_size=chunk_size, src_key_padding_mask=( @@ -283,8 +270,6 @@ def randomly_choose_seqs(x, this_x, prob: float): aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) ) x = upsample_by(x, ds) - if sd_prob: - x_sd = randomly_choose_seqs(x_sd, upsample_by(this_x_sd, ds), 1. / (2. + i)) assert self.output_downsampling_factor == 2, self.output_downsampling_factor @@ -299,11 +284,6 @@ def randomly_choose_seqs(x, this_x, prob: float): warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - if sd_prob: - x_sd = downsample_by(x_sd, od) - x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding - x = randomly_choose_seqs(x, x_sd, sd_prob) - return x, lengths def _get_attn_mask( @@ -854,9 +834,6 @@ def __init__( self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) self.out_proj.lr_scale = 0.75 - # stochastic-depth proj. - self.sd_proj = nn.Linear(encoder_layer.embed_dim, dim) - def forward( self, @@ -925,12 +902,10 @@ def forward( aux_loss_scale, src_key_padding_mask), None) - src_sd = self.sd_proj(offset) - if hasattr(self, 'out_proj'): src = self.out_proj(src) - return src, src_sd + return src def streaming_forward( @@ -1868,7 +1843,6 @@ def _test_zipformer_main(causal: bool = False): torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), aux_loss_scale=1.0, - sd_prob=0.1, ) f.sum().backward() c.eval() @@ -1876,7 +1850,6 @@ def _test_zipformer_main(causal: bool = False): torch.randn(seq_len, batch_size, input_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), aux_loss_scale=1.0, - sd_prob=0.1, ) x_ # to remove flake8 warnings From 006b15e8efcdc9e4a940ee82cbb9c2d809cc7e9d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Oct 2025 05:57:00 +0800 Subject: [PATCH 0573/1191] Bug fix and improvement in efficiency --- egs/librispeech/ASR/zapformer/model.py | 46 +++++++------------------- 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 4a72add0c4..2f920fd654 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -86,8 +86,6 @@ def __init__( self.encoder_embed = encoder_embed self.encoder = encoder - self.predict_loss = PredictLoss(encoder_dim) - self.use_transducer = use_transducer if use_transducer: # Modules for Transducer head @@ -169,28 +167,11 @@ def forward_encoder( encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, aux_loss_scale=aux_loss_scale) - predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - return encoder_out, encoder_out_lens, predict_loss - - - def compute_predict_loss(self, - encoder_out: Tensor, - src_key_padding_mask: Optional[Tensor], - specaug_mask: Optional[Tensor]) -> Tensor: - if src_key_padding_mask is not None and specaug_mask is not None: - mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) - elif src_key_padding_mask is not None: - mask = src_key_padding_mask.t().logical_not() - elif specaug_mask is not None: - mask = specaug_mask.t().logical_not() - else: - mask = None - return self.predict_loss(encoder_out, mask) + return encoder_out, encoder_out_lens def forward_ctc( @@ -498,8 +479,8 @@ def forward( # Compute encoder outputs - encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, - aux_loss_scale=aux_loss_scale) + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -553,7 +534,7 @@ def forward( reconstruction_loss = self.forward_reconstruction_loss(self.gauss_norm(x_no_specaug, x_lens), encoder_out, encoder_out_lens) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss @@ -562,30 +543,27 @@ def gauss_norm(self, log_mel_lens: Tensor) -> Tensor: (batch_size, seq_len, num_channels) = log_mels.shape - randpos = torch.randint(seq_len, (batch_size, seq_len, num_channels), device=log_mels.device) - + rand_pos = torch.randint(seq_len, (batch_size, seq_len, num_channels), device=log_mels.device) rand_pos = rand_pos % log_mel_lens.unsqueeze(-1).unsqueeze(-1) + arange = torch.arange(seq_len, device=log_mels.device)[None, :, None].expand_as(rand_pos) + length_mask = make_pad_mask(log_mel_lens) # True in masked positions - - log_mels_rand = torch.gather(log_mels, dim=1, index=rand_pos) - length_mask = make_pad_mask(encoder_out_lens) # True in masked positions + # select the "self" position if we are in the non-masked region; select random + # non-masked positions when in padding regions. length_mask = length_mask.unsqueeze(-1).expand_as(log_mels) - - log_mels = torch.where(length_mask, log_mels_rand, log_mels) - # OK, now for out-of-bounds positions we have selected randomly chosen within-bounds positions. + log_mels = torch.gather(log_mels, dim=1, index=torch.where(length_mask, rand_pos, arange)) values, indexes = log_mels.sort(dim=1) # sort on seq dim N = max(2, log_mels.shape[1]) - norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, log_mels.shape[1], device=x.device, dtype=torch.float) + norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, log_mels.shape[1], device=log_mels.device, dtype=torch.float) norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data norm_rank = norm_rank.reshape(1, -1, 1) - norm_rank = norm_rank.repeat(log_mels.shape[0], 1, x.shape[2]) + norm_rank = norm_rank.repeat(log_mels.shape[0], 1, log_mels.shape[2]) log_mels_norm = torch.empty_like(log_mels) log_mels_norm.scatter_(dim=1, index=indexes, src=norm_rank) return log_mels_norm - def forward_reconstruction_loss(self, log_mels: Tensor, encoder_out: Tensor, From 303c8132ce83e2dccaf3d7f37011a4484f6164ef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Oct 2025 09:55:41 +0800 Subject: [PATCH 0574/1191] Bug fix to issue that biased rand positions. --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 2f920fd654..ed8cbc2cf8 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -543,7 +543,7 @@ def gauss_norm(self, log_mel_lens: Tensor) -> Tensor: (batch_size, seq_len, num_channels) = log_mels.shape - rand_pos = torch.randint(seq_len, (batch_size, seq_len, num_channels), device=log_mels.device) + rand_pos = torch.randint(100000000, (batch_size, seq_len, num_channels), device=log_mels.device) rand_pos = rand_pos % log_mel_lens.unsqueeze(-1).unsqueeze(-1) arange = torch.arange(seq_len, device=log_mels.device)[None, :, None].expand_as(rand_pos) length_mask = make_pad_mask(log_mel_lens) # True in masked positions From 9168852d9ea295118393ade7996337586158600f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Oct 2025 13:44:52 +0800 Subject: [PATCH 0575/1191] Replace tanh in ExpNorm with clamp(max=1.0) --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 00b3eb884c..14dbd2845e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, floor: Optional[Tensor]): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = torch.nn.functional.tanh(x_norm) + num = num.clamp(max=1.0) if floor is not None: num = torch.maximum(num, floor) scales = num / x_norm From 44e56c7d1911f94807a43cf4e0e51c71712d570d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Oct 2025 13:56:00 +0800 Subject: [PATCH 0576/1191] Bug fix; apply scale_limiter in subsampling.py and do not use rand_floor. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/subsampling.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 14dbd2845e..724db5a4ba 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, floor: Optional[Tensor]): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = num.clamp(max=1.0) + num = x_norm.clamp(max=1.0) if floor is not None: num = torch.maximum(num, floor) scales = num / x_norm diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 69df15072b..ad297a21a0 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -247,7 +247,9 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) - self.out_norm = ExpNorm(out_channels) + + self.scale_limiter = ScaleLimiter(min_rms=0.15, max_rms=2.0) + self.out_norm = ExpNorm(out_channels, rand_floor=0.0) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, @@ -287,6 +289,7 @@ def forward( key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) x = with_loss(x, self.cosine_loss(x, aux_loss_scale, key_padding_mask), None) + x = self.scale_limiter(x, aux_loss_scale) x = self.out_norm(x) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) From 0d7a93a4f697c7dde1c3a98225021a33472cd4de Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 06:35:46 +0800 Subject: [PATCH 0577/1191] In conv module, normalize input to sigmoid by dividing by rms. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 40771e152c..626f081e57 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1732,10 +1732,13 @@ def forward( """ + rms = (x ** 2).mean(dim=-1, keepdim=True).sqrt() + x = self.in_proj(x) # (time, batch, 2*channels) + x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) + s = self.sigmoid(s / rms) x = self.activation1(x) # identity. x = x * s x = self.activation2(x) # identity From 206fc86ce76f6fb65f5a1752a9c27e11a748f715 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 07:07:03 +0800 Subject: [PATCH 0578/1191] Revert clamp(max=1.0) to tanh() in ExpNorm. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 724db5a4ba..00b0f49ff5 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, floor: Optional[Tensor]): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = x_norm.clamp(max=1.0) + num = x_norm.tanh() if floor is not None: num = torch.maximum(num, floor) scales = num / x_norm From e0656177a431fd0a650855bf3700fd321fa7635c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 07:17:29 +0800 Subject: [PATCH 0579/1191] Change rms formula in conv_module2 to mostly affect larger values, via an offset of 0.2. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 626f081e57..5dac637dac 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1732,7 +1732,7 @@ def forward( """ - rms = (x ** 2).mean(dim=-1, keepdim=True).sqrt() + rms = ((x ** 2).mean(dim=-1, keepdim=True) + 0.2).sqrt() x = self.in_proj(x) # (time, batch, 2*channels) From 32b701dbaced58f58f04e52569714577e4244234 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 07:48:12 +0800 Subject: [PATCH 0580/1191] Remove max-product stuff and penalize too-large attention scores a different way. --- egs/librispeech/ASR/zipformer/zipformer.py | 65 ++++++++++++++++------ 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 40771e152c..418112dfd8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -35,8 +35,6 @@ ExpNorm, ChunkCausalDepthwiseConv1d, CosineSimilarityLoss, - MinProductLoss, - MaxProductLoss, FloatLike, convert_num_channels, limit_param_value, @@ -1133,6 +1131,47 @@ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: return pos_emb +class PenalizeLargeAttentionScores(torch.autograd.Function): + @staticmethod + def forward( + ctx, + attn_scores: Tensor, + limit: float, + aux_loss_scale: float, + name: str): + # attn_scores: (head, batch, query_time, key_time) + ctx.save_for_backward(attn_scores) + ctx.limit = limit + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return attn_scores + + @staticmethod + def backward( + ctx, + attn_scores_grad): + attn_scores, = ctx.saved_tensors + (num_heads, batch_size, seq_len, _) = attn_scores.shape + with torch.amp.autocast('cuda', enabled=False): + attn_scores = attn_scores.to(torch.float) + attn_scores = attn_scores.detach() + attn_scores.requires_grad = True + with torch.enable_grad(): + probs = attn_scores.softmax(dim=-1) + # attn_scores: (head, batch, query_time, key_time) + avg_scores = (attn_scores.abs() * probs).sum(dim=-1).mean(dim=(-2,-1)) + # avg_scores: (num_heads,), we want these to not exceed limit. + penalty = (avg_scores - ctx.limit).relu() + if random.random() < 0.001: + logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, penalty={penalty}") + # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number + # of frames which is batch_size * seq_len. normalize by dividing by num heads. + penalty.backward(gradient=torch.full_like(penalty, ctx.aux_loss_scale * batch_size * seq_len / num_heads)) + return attn_scores_grad + attn_scores.grad + + + + class RelPositionMultiheadAttentionWeights(nn.Module): r"""Module that computes multi-head attention weights with relative position encoding. Various other modules consume the resulting attention weights: see, for example, the @@ -1194,9 +1233,6 @@ def __init__( self.copy_query = Identity() self.copy_key = Identity() - self.qk_max_product = MaxProductLoss(max_product=6.0) - self.pos_max_product = MaxProductLoss(max_product=4.0) - def forward( self, @@ -1260,15 +1296,8 @@ def forward( p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - if self.training: - k = with_loss(k, - self.qk_max_product(q.reshape(num_heads * batch_size, seq_len, query_head_dim), - k.permute(0, 1, 3, 2).reshape(num_heads * batch_size, seq_len, query_head_dim), - aux_loss_scale / num_heads), - None) - - attn_scores = torch.matmul(q, k) + # attn_scores: (head, batch, query_time, key_time) if True: # position scores. @@ -1282,10 +1311,6 @@ def forward( if self.training: pe = pos_emb.expand(num_heads, batch_size, pos_head_dim, seq_len2) pe = pe.reshape(num_heads * batch_size, pos_head_dim, seq_len2).permute(0, 2, 1) - p = with_loss(p, - self.pos_max_product(p.reshape(num_heads * batch_size, seq_len, pos_head_dim), pe, - aux_loss_scale / num_heads), - None) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.] @@ -1337,6 +1362,12 @@ def forward( -1000, ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + attn_scores_limit = 4.0 # limit on our metric that reflects how much grad we are likely to backpropagate. + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, aux_loss_scale, self.name) + + # We use our own version of softmax, defined in scaling.py, which should # save a little of the memory used in backprop by, if we are in # automatic mixed precision mode (amp / autocast), by only storing the From 9ab112366c3ac81a9147041db7c8309b8d34b769 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 07:52:56 +0800 Subject: [PATCH 0581/1191] Bug fix --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 418112dfd8..3926ccf112 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1163,11 +1163,11 @@ def backward( # avg_scores: (num_heads,), we want these to not exceed limit. penalty = (avg_scores - ctx.limit).relu() if random.random() < 0.001: - logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, penalty={penalty}") + logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, penalty={penalty}") # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number # of frames which is batch_size * seq_len. normalize by dividing by num heads. penalty.backward(gradient=torch.full_like(penalty, ctx.aux_loss_scale * batch_size * seq_len / num_heads)) - return attn_scores_grad + attn_scores.grad + return attn_scores_grad + attn_scores.grad, None, None, None From 550514e6e5e43d00ced02d744714a257ba70f416 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 08:43:51 +0800 Subject: [PATCH 0582/1191] Change how penalty on scores is computed, have separate key and query excesses. --- egs/librispeech/ASR/zipformer/zipformer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a9a9b41aeb..cefa18e99b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1159,14 +1159,18 @@ def backward( with torch.enable_grad(): probs = attn_scores.softmax(dim=-1) # attn_scores: (head, batch, query_time, key_time) - avg_scores = (attn_scores.abs() * probs).sum(dim=-1).mean(dim=(-2,-1)) - # avg_scores: (num_heads,), we want these to not exceed limit. - penalty = (avg_scores - ctx.limit).relu() + scaled_scores = attn_scores.abs() * probs + query_scores = (scaled_scores.sum(dim=-1) - ctx.limit).relu() + key_scores = (scaled_scores.sum(dim=-2) - ctx.limit).relu() + if random.random() < 0.001: - logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, penalty={penalty}") + query_excess = query_scores.mean(dim=(1,2)) + key_excess = key_scores.mean(dim=(1,2)) + logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, query_excess={query_excess}, key_excess={key_excess}") # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number # of frames which is batch_size * seq_len. normalize by dividing by num heads. - penalty.backward(gradient=torch.full_like(penalty, ctx.aux_loss_scale * batch_size * seq_len / num_heads)) + (query_scores + key_scores).backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / num_heads)) + return attn_scores_grad + attn_scores.grad, None, None, None From d28e5f141322384a548d626eb13e0f1bc0498463 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 08:54:37 +0800 Subject: [PATCH 0583/1191] Also divide by ctx.limit so it's like penalizing a relative excess. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index cefa18e99b..0fbba9e1a0 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1169,7 +1169,8 @@ def backward( logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, query_excess={query_excess}, key_excess={key_excess}") # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number # of frames which is batch_size * seq_len. normalize by dividing by num heads. - (query_scores + key_scores).backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / num_heads)) + # also divide by ctx.limit so it's like penalizing a relative excess. + (query_scores + key_scores).backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / (num_heads * ctx.limit))) return attn_scores_grad + attn_scores.grad, None, None, None From 5a452833903cf74d90f3a64ba856ce65d6bd1160 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 11:32:20 +0800 Subject: [PATCH 0584/1191] Penalize only query excess not key excess; and increase limit from 4 to 5. --- egs/librispeech/ASR/zipformer/zipformer.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0fbba9e1a0..2c1d31a4f1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1161,16 +1161,14 @@ def backward( # attn_scores: (head, batch, query_time, key_time) scaled_scores = attn_scores.abs() * probs query_scores = (scaled_scores.sum(dim=-1) - ctx.limit).relu() - key_scores = (scaled_scores.sum(dim=-2) - ctx.limit).relu() if random.random() < 0.001: query_excess = query_scores.mean(dim=(1,2)) - key_excess = key_scores.mean(dim=(1,2)) - logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, query_excess={query_excess}, key_excess={key_excess}") + logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, query_excess={query_excess}") # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number # of frames which is batch_size * seq_len. normalize by dividing by num heads. # also divide by ctx.limit so it's like penalizing a relative excess. - (query_scores + key_scores).backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / (num_heads * ctx.limit))) + query_scores.backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / (num_heads * ctx.limit))) return attn_scores_grad + attn_scores.grad, None, None, None @@ -1369,7 +1367,7 @@ def forward( if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - attn_scores_limit = 4.0 # limit on our metric that reflects how much grad we are likely to backpropagate. + attn_scores_limit = 5.0 # limit on our metric that reflects how much grad we are likely to backpropagate. attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, aux_loss_scale, self.name) From c42c01eddfee6d87b4aa43af255999b9de75df7b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Oct 2025 12:10:07 +0800 Subject: [PATCH 0585/1191] Increase penalty from 5 to 8. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2c1d31a4f1..361575df1e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1367,7 +1367,7 @@ def forward( if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - attn_scores_limit = 5.0 # limit on our metric that reflects how much grad we are likely to backpropagate. + attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, aux_loss_scale, self.name) From 7588e1548fc6408762aceb7343210ed81556d4e1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Oct 2025 04:15:55 +0800 Subject: [PATCH 0586/1191] Reduce power of cosine loss of feeedforward module from .7 to .65, and offset loss of zipformer layer from .8 to .7. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 361575df1e..3fd49bb695 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -549,7 +549,7 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.self_attn_weights = RelPositionMultiheadAttentionWeights( @@ -1673,7 +1673,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int): initial_scale=0.5, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.65)) def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: From 7daad71436bccc29ec1bf79ab7301b8fb2f569d6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Oct 2025 06:55:15 +0800 Subject: [PATCH 0587/1191] Remove division by rms from ConvolutionMOdule. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3fd49bb695..7ed6c8b1a1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1765,14 +1765,11 @@ def forward( Tensor: Output tensor (#time, batch, channels). """ - - rms = ((x ** 2).mean(dim=-1, keepdim=True) + 0.2).sqrt() - x = self.in_proj(x) # (time, batch, 2*channels) x, s = x.chunk(2, dim=2) - s = self.sigmoid(s / rms) + s = self.sigmoid(s) x = self.activation1(x) # identity. x = x * s x = self.activation2(x) # identity From 02d6883eb677629ec49fc697be88ae055beda61f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Oct 2025 04:16:56 +0800 Subject: [PATCH 0588/1191] Introduce epsilon floor on ExpNorm. --- egs/librispeech/ASR/zipformer/scaling.py | 40 +++++++------------- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 00b0f49ff5..1c263dde72 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,11 +333,10 @@ def backward(ctx, x_grad, *args): -def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, floor: Optional[Tensor]): - x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = x_norm.tanh() - if floor is not None: - num = torch.maximum(num, floor) +def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps: float): + var = torch.mean(x ** 2, dim=channel_dim, keepdim=True) + x_norm, x_norm_witheps = var.sqrt(), (v + eps**2).sqrt() + num = x_norm_witheps.tanh() scales = num / x_norm scales = scale * scales return (x * scales) @@ -350,33 +349,22 @@ class ExpNormFunction(torch.autograd.Function): # (scale is a user-provided scaling factor that is learnable).. # return x * scales # - # .. if rand_floor != 0.0, it does a randomized method that sometimes modifies - # (increases) the scale if x_norm is less than rand_floor; this is intended - # to penalize too-small x values, which can otherwise occur after a lot of training - # and could destabilize the network's training. @staticmethod def forward( ctx, x: Tensor, scale: Tensor, channel_dim: int, - rand_floor: float, + eps: float, ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim - ctx.rand_floor = rand_floor - if rand_floor != 0.0: - shape = list(x.shape) - shape[channel_dim] = 1 - floor = torch.where(torch.rand(*shape, device=x.device) < 0.1, rand_floor, 0.0) - else: - floor = None - ctx.floor = floor + ctx.eps = eps ctx.save_for_backward(x, scale) - return _exp_norm(x, scale, channel_dim, floor) + return _exp_norm(x, scale, channel_dim, eps) @staticmethod @@ -391,7 +379,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: scale.requires_grad = True with torch.enable_grad(): - ans = _exp_norm(x, scale, ctx.channel_dim, ctx.floor) + ans = _exp_norm(x, scale, ctx.channel_dim, ctx.eps) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -429,22 +417,20 @@ class ExpNorm(torch.nn.Module): interpreted as an offset from the input's ndim if negative. This is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - rand_floor: if not 0.0: during training, for 10% of the vectors - we will randomly floor the numerator of the expression for the - scales (1. - (-x_norm).exp()), to this value. This is intended - to discourage the network to make the inputs smaller than this. + eps: a mechanism to discourage too-small inputs by making the function + nonlinear below approximately this value. """ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - rand_floor: FloatLike = 0.15, + eps: float = 0.1, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(1.7)) - self.rand_floor = rand_floor + self.eps = eps self.name = None @@ -459,7 +445,7 @@ def forward(self, x: Tensor) -> Tensor: self.scale, min=0.5, max=2.5, training=self.training) ans = ExpNormFunction.apply( - x, scale, self.channel_dim, float(self.rand_floor) if self.training else 0.0, + x, scale, self.channel_dim, self.eps, ) if random.random() < 0.002: diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index ad297a21a0..6dce6d7be4 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -249,7 +249,7 @@ def __init__( self.scale_limiter = ScaleLimiter(min_rms=0.15, max_rms=2.0) - self.out_norm = ExpNorm(out_channels, rand_floor=0.0) + self.out_norm = ExpNorm(out_channels) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7ed6c8b1a1..fe24eaf9bd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -576,7 +576,7 @@ def __init__( self.scale_limiter = ScaleLimiter(min_rms=0.15, max_rms=2.0) - self.norm = ExpNorm(embed_dim, rand_floor=0.0) # rely on scale_limiter for floor on norm. + self.norm = ExpNorm(embed_dim) def forward( From b2f081dae91fdd9b6314cbae972f33356acecb9b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Oct 2025 04:31:43 +0800 Subject: [PATCH 0589/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1c263dde72..3e623377da 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps: float): var = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - x_norm, x_norm_witheps = var.sqrt(), (v + eps**2).sqrt() + x_norm, x_norm_witheps = var.sqrt(), (var + eps**2).sqrt() num = x_norm_witheps.tanh() scales = num / x_norm scales = scale * scales From 14fae159c3d6e65dce27272737583b7d26efc0ef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Oct 2025 04:32:15 +0800 Subject: [PATCH 0590/1191] Remove all scale limiters (for real this time) --- egs/librispeech/ASR/zipformer/subsampling.py | 3 --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 6dce6d7be4..4aabbf7b47 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -21,7 +21,6 @@ import torch from scaling import ( - ScaleLimiter, ScaledLinear, ExpNorm, FloatLike, @@ -248,7 +247,6 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) - self.scale_limiter = ScaleLimiter(min_rms=0.15, max_rms=2.0) self.out_norm = ExpNorm(out_channels) def forward( @@ -289,7 +287,6 @@ def forward( key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) x = with_loss(x, self.cosine_loss(x, aux_loss_scale, key_padding_mask), None) - x = self.scale_limiter(x, aux_loss_scale) x = self.out_norm(x) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fe24eaf9bd..6d3bab92aa 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -30,7 +30,6 @@ OrthogonalLinear, SimpleOrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - ScaleLimiter, ActivationDropoutAndLinear, ExpNorm, ChunkCausalDepthwiseConv1d, @@ -574,8 +573,6 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(min_rms=0.15, max_rms=2.0) - self.norm = ExpNorm(embed_dim) @@ -651,8 +648,6 @@ def forward( self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), None) - src = self.scale_limiter(src, aux_loss_scale) - src = self.norm(src) return src From 81a7fdb08631687cb0b0cd8682d91f3f1cfb4ce0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Oct 2025 04:45:50 +0800 Subject: [PATCH 0591/1191] Add eps directly without square root formula --- egs/librispeech/ASR/zipformer/scaling.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3e623377da..fd22dad655 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -334,9 +334,8 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps: float): - var = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - x_norm, x_norm_witheps = var.sqrt(), (var + eps**2).sqrt() - num = x_norm_witheps.tanh() + x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() + num = (x_norm + eps).tanh() scales = num / x_norm scales = scale * scales return (x * scales) From 903ba2f159146866ed11954690736a4e2d743c2c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Oct 2025 07:42:24 +0800 Subject: [PATCH 0592/1191] Reduce max of scale from 2.5 to 1.0. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index fd22dad655..0bf75907c2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -441,7 +441,7 @@ def forward(self, x: Tensor) -> Tensor: return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.5, max=2.5, training=self.training) + self.scale, min=0.5, max=1.0, training=self.training) ans = ExpNormFunction.apply( x, scale, self.channel_dim, self.eps, From 86feb0f1f1de3d5c4e6bf6ea039e4280fa4ea7d8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Oct 2025 07:45:07 +0800 Subject: [PATCH 0593/1191] Revert max of scale to 2.5 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 0bf75907c2..fd22dad655 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -441,7 +441,7 @@ def forward(self, x: Tensor) -> Tensor: return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.5, max=1.0, training=self.training) + self.scale, min=0.5, max=2.5, training=self.training) ans = ExpNormFunction.apply( x, scale, self.channel_dim, self.eps, From c40d8dfe9dab5c9037b756c780dca43cf3bf81ea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Oct 2025 07:54:42 +0800 Subject: [PATCH 0594/1191] Add eps2=0.01 in EpsNorm, discourage very large inputs. --- egs/librispeech/ASR/zipformer/scaling.py | 33 +++++++++++------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index fd22dad655..f539d4e741 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,37 +333,32 @@ def backward(ctx, x_grad, *args): -def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps: float): +def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps1: float, eps2: float): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (x_norm + eps).tanh() + num = (x_norm + eps1).tanh() - eps2 * x_norm scales = num / x_norm scales = scale * scales return (x * scales) class ExpNormFunction(torch.autograd.Function): - # This computes: - # Equivalent to: - # x_norm = x.norm(dim=-1, keepdim=True) - # scales = x_norm.tanh() / x_norm * scale - # (scale is a user-provided scaling factor that is learnable).. - # return x * scales - # @staticmethod def forward( ctx, x: Tensor, scale: Tensor, channel_dim: int, - eps: float, + eps1: float, + eps2: float, ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim - ctx.eps = eps + ctx.eps1 = eps1 + ctx.eps2 = eps2 ctx.save_for_backward(x, scale) - return _exp_norm(x, scale, channel_dim, eps) + return _exp_norm(x, scale, channel_dim, eps1, eps2) @staticmethod @@ -378,7 +373,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: scale.requires_grad = True with torch.enable_grad(): - ans = _exp_norm(x, scale, ctx.channel_dim, ctx.eps) + ans = _exp_norm(x, scale, ctx.channel_dim, ctx.eps1, ctx.eps2) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -386,7 +381,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(scale.grad), None, None + return x.grad, c(scale.grad), None, None, None @@ -423,13 +418,15 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.1, + eps1: float = 0.05, + eps2: float = 0.01, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(1.7)) - self.eps = eps + self.eps1 = eps1 + self.eps2 = eps2 self.name = None @@ -438,13 +435,13 @@ def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _exp_norm(x, self.scale, self.channel_dim) + return _exp_norm(x, self.scale, self.channel_dim, self.eps1, self.eps2) scale = limit_param_value( self.scale, min=0.5, max=2.5, training=self.training) ans = ExpNormFunction.apply( - x, scale, self.channel_dim, self.eps, + x, scale, self.channel_dim, self.eps1, self.eps2 ) if random.random() < 0.002: From 38950864838472557ec1223f7a94eca2dd568ebb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Oct 2025 09:46:49 +0800 Subject: [PATCH 0595/1191] Restore ScaleLImiter to zipformer layers, with max_rms=2.0 and min_rms=0.05 so hopefully inactive. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6d3bab92aa..2f3485990b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -39,6 +39,7 @@ limit_param_value, penalize_abs_values_gt, softmax, + ScaleLimiter, with_loss, ) from torch import Tensor, nn @@ -573,6 +574,8 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) + self.norm = ExpNorm(embed_dim) @@ -648,6 +651,9 @@ def forward( self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), None) + + src = self.scale_limiter(src, aux_loss_scale) + src = self.norm(src) return src From 7eb100a8360332a8eacf1dbfd484c94457bb2324 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Oct 2025 09:55:48 +0800 Subject: [PATCH 0596/1191] Add clamp(min=0.0) in formula for ExpNorm. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index f539d4e741..764165905e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps1: float, eps2: float): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (x_norm + eps1).tanh() - eps2 * x_norm + num = ((x_norm + eps1).tanh() - eps2 * x_norm).clamp(min=0.0) scales = num / x_norm scales = scale * scales return (x * scales) From 117813fbba460d0cc7b9a54b02e13dbd1a75ffcd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Oct 2025 11:23:02 +0800 Subject: [PATCH 0597/1191] Remove eps2 from formula; increase eps1 to 0.1 and the min limits of scale_limiters to 0.1. --- egs/librispeech/ASR/zipformer/scaling.py | 22 ++++++++------------ egs/librispeech/ASR/zipformer/subsampling.py | 3 +++ egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 764165905e..2ca8084784 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,9 +333,9 @@ def backward(ctx, x_grad, *args): -def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps1: float, eps2: float): +def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps1: float): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = ((x_norm + eps1).tanh() - eps2 * x_norm).clamp(min=0.0) + num = (x_norm + eps1).tanh() scales = num / x_norm scales = scale * scales return (x * scales) @@ -348,17 +348,15 @@ def forward( scale: Tensor, channel_dim: int, eps1: float, - eps2: float, ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim ctx.eps1 = eps1 - ctx.eps2 = eps2 ctx.save_for_backward(x, scale) - return _exp_norm(x, scale, channel_dim, eps1, eps2) + return _exp_norm(x, scale, channel_dim, eps1) @staticmethod @@ -373,7 +371,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: scale.requires_grad = True with torch.enable_grad(): - ans = _exp_norm(x, scale, ctx.channel_dim, ctx.eps1, ctx.eps2) + ans = _exp_norm(x, scale, ctx.channel_dim, ctx.eps1) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -381,7 +379,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(scale.grad), None, None, None + return x.grad, c(scale.grad), None, None @@ -411,22 +409,20 @@ class ExpNorm(torch.nn.Module): interpreted as an offset from the input's ndim if negative. This is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - eps: a mechanism to discourage too-small inputs by making the function + eps1: a mechanism to discourage too-small inputs by making the function nonlinear below approximately this value. """ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps1: float = 0.05, - eps2: float = 0.01, + eps1: float = 0.1, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(1.7)) self.eps1 = eps1 - self.eps2 = eps2 self.name = None @@ -435,13 +431,13 @@ def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _exp_norm(x, self.scale, self.channel_dim, self.eps1, self.eps2) + return _exp_norm(x, self.scale, self.channel_dim, self.eps1) scale = limit_param_value( self.scale, min=0.5, max=2.5, training=self.training) ans = ExpNormFunction.apply( - x, scale, self.channel_dim, self.eps1, self.eps2 + x, scale, self.channel_dim, self.eps1 ) if random.random() < 0.002: diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 4aabbf7b47..ff3746d896 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -21,6 +21,7 @@ import torch from scaling import ( + ScaleLimiter, ScaledLinear, ExpNorm, FloatLike, @@ -247,6 +248,7 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) + self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) self.out_norm = ExpNorm(out_channels) def forward( @@ -287,6 +289,7 @@ def forward( key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) x = with_loss(x, self.cosine_loss(x, aux_loss_scale, key_padding_mask), None) + x = self.scale_limiter(x, aux_loss_scale) x = self.out_norm(x) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2f3485990b..2394d1c375 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -574,7 +574,7 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) self.norm = ExpNorm(embed_dim) From 016ab103a9a601afba8447de45658aebfabf85ec Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Oct 2025 12:40:50 +0800 Subject: [PATCH 0598/1191] Restore min_rms and eps1 from 0.1 to 0.05; reduce min of scale from 0.5 to 0.4. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2ca8084784..8ff815d1df 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -416,7 +416,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps1: float = 0.1, + eps1: float = 0.05, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels @@ -434,7 +434,7 @@ def forward(self, x: Tensor) -> Tensor: return _exp_norm(x, self.scale, self.channel_dim, self.eps1) scale = limit_param_value( - self.scale, min=0.5, max=2.5, training=self.training) + self.scale, min=0.4, max=2.5, training=self.training) ans = ExpNormFunction.apply( x, scale, self.channel_dim, self.eps1 diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index ff3746d896..d232f90542 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -248,7 +248,7 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) - self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) self.out_norm = ExpNorm(out_channels) def forward( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2394d1c375..2f3485990b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -574,7 +574,7 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) self.norm = ExpNorm(embed_dim) From 53774fca1594a935fd19b58dd7959ae93840d15c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Oct 2025 01:17:00 +0800 Subject: [PATCH 0599/1191] Reduce min_rms of ScaleLimiter instances from 0.05 to 0.02 --- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index d232f90542..248aa0df6e 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -248,7 +248,7 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) - self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.02, max_rms=2.0) self.out_norm = ExpNorm(out_channels) def forward( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2f3485990b..447c909c74 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -574,7 +574,7 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.02, max_rms=2.0) self.norm = ExpNorm(embed_dim) From 2d6e99df0bc2c6a68c3e3e743d9da0b3fe1b8f8e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Oct 2025 09:06:10 +0800 Subject: [PATCH 0600/1191] Change the formula from (x + 0.05).tanh() to ((0.05 + x**0.8) ** (1./0.8)).tanh() --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8ff815d1df..174e1c1a4e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps1: float): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (x_norm + eps1).tanh() + num = ((x_norm ** 0.8 + eps1) ** 1. / 0.8).tanh() scales = num / x_norm scales = scale * scales return (x * scales) From e189c06abe1de14af5f084abad52ec6db4d69eee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Oct 2025 09:13:00 +0800 Subject: [PATCH 0601/1191] Increase default eps1 from 0.05 to 0.075. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 174e1c1a4e..27615fd8d6 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -416,7 +416,7 @@ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps1: float = 0.05, + eps1: float = 0.075, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels From c12197ee14b68e88cd402eac6af6ecc597c4865d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Oct 2025 09:32:34 +0800 Subject: [PATCH 0602/1191] Change formula to ((0.15 + x_norm ** 0.5) ** 2).tanh() --- egs/librispeech/ASR/zipformer/scaling.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 27615fd8d6..d7d406142e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,9 +333,10 @@ def backward(ctx, x_grad, *args): -def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int, eps1: float): +def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): + # plot (0.15 + 0.87*x**0.5)^(1/0.5), x+0.075 for 0 <= x <= 1 x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = ((x_norm ** 0.8 + eps1) ** 1. / 0.8).tanh() + num = ((0.15 + x_norm ** 0.5) ** 2).tanh() scales = num / x_norm scales = scale * scales return (x * scales) @@ -347,16 +348,14 @@ def forward( x: Tensor, scale: Tensor, channel_dim: int, - eps1: float, ) -> Tensor: if channel_dim < 0: channel_dim = channel_dim + x.ndim ctx.channel_dim = channel_dim - ctx.eps1 = eps1 ctx.save_for_backward(x, scale) - return _exp_norm(x, scale, channel_dim, eps1) + return _exp_norm(x, scale, channel_dim) @staticmethod @@ -371,7 +370,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: scale.requires_grad = True with torch.enable_grad(): - ans = _exp_norm(x, scale, ctx.channel_dim, ctx.eps1) + ans = _exp_norm(x, scale, ctx.channel_dim) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -379,8 +378,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(scale.grad), None, None - + return x.grad, c(scale.grad), None class ExpNorm(torch.nn.Module): @@ -409,20 +407,16 @@ class ExpNorm(torch.nn.Module): interpreted as an offset from the input's ndim if negative. This is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - eps1: a mechanism to discourage too-small inputs by making the function - nonlinear below approximately this value. """ def __init__( self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps1: float = 0.075, ) -> None: super(ExpNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.scale = nn.Parameter(torch.tensor(1.7)) - self.eps1 = eps1 self.name = None @@ -431,13 +425,13 @@ def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _exp_norm(x, self.scale, self.channel_dim, self.eps1) + return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( self.scale, min=0.4, max=2.5, training=self.training) ans = ExpNormFunction.apply( - x, scale, self.channel_dim, self.eps1 + x, scale, self.channel_dim, ) if random.random() < 0.002: From 08931c171359225d4b8b8b0ca1c10b2d99708751 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Oct 2025 10:06:16 +0800 Subject: [PATCH 0603/1191] Implement CrossCosineLoss with max_product=0.2, in zipformer layers. --- egs/librispeech/ASR/zipformer/scaling.py | 69 ++++++++++++---------- egs/librispeech/ASR/zipformer/zipformer.py | 13 ++++ 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d7d406142e..74ff5706a1 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1183,78 +1183,88 @@ def forward(self, z = with_loss(z, ret, None) where z is any quantity that will be used in calculating the main loss. Ret will always be numerically equal to zero in the forward pass but - may behave as if it were nonzero for backprop purposes. + will behave as if it were nonzero for backprop purposes. """ return MinProductLossFunction.apply(x, y, mask, float(self.min_product), loss_scale, self.name) -class MaxProductLossFunction(torch.autograd.Function): +# cross cosine loss is for when you have a situation like: +# x = x + delta +# x = with_loss(x, cross_cosine_loss(x, delta)) +# and we want to make sure that delta only represents, on average, +# a small fraction of the total x. That is, mean(abs((delta . x) / (x . x))) < max_product +# you could also probably supply the original x and it would have a somewhat similar effect. +class CrossCosineLossFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x: Tensor, y: Tensor, max_product: float, weight: float, name: str): - ctx.save_for_backward(x, y) + def forward(ctx, x: Tensor, delta: Tensor, mask: Optional[Tensor], + max_product: float, weight: float, name: str): + ctx.save_for_backward(x, delta) ctx.name = name + ctx.mask = mask # mask will have no grad so it should be OK to store this way ctx.weight = weight ctx.max_product = max_product + # return fake loss that is always zero but behaves in backprop as if it were a real loss. return torch.tensor(0.0, device=x.device, dtype=x.dtype) @staticmethod @custom_bwd def backward(ctx, ans_grad): - x, y = ctx.saved_tensors + x, delta = ctx.saved_tensors name = ctx.name # str + mask = ctx.mask # Tensor or None, shape: (batch_size, seq_len) weight = ctx.weight # float max_product = ctx.max_product # float + (batch_size, seq_len, num_channels) = x.shape with torch.enable_grad(): - x, y = x.detach(), y.detach() + x, delta = x.detach(), delta.detach() x.requires_grad = True - y.requires_grad = True + delta.requires_grad = True - (batch_size, seq_len, num_channels) = x.shape - seq_len2 = y.shape[1] - indexes = torch.randint(0, seq_len2, (batch_size, seq_len, 1), device=x.device) + eps = 3.0e-08 # won't be zero in float16 - y_rand = torch.gather(y, 1, indexes.expand(*x.shape)) + product = (x * delta).sum(dim=-1).abs() / ((x * x).sum(dim=-1) + eps) - product = (x * y_rand).sum(dim=-1).abs() + if mask is not None: + product = product * (~mask).to(product.dtype) + product = product.abs() excess_product = (product.sum(dim=1) - seq_len * max_product).relu() if random.random() < 0.001: - logging.info(f"MaxProduct: {name}, limit={max_product}, excess-product={excess_product.mean() / seq_len}") + logging.info(f"CrossCosineLoss: {name}, limit={max_product}, excess-product={excess_product.mean() / seq_len}") grad = (weight * ans_grad).expand(excess_product.numel()) excess_product.backward(grad) - return x.grad, y.grad, None, None, None + return x.grad, delta.grad, None, None, None, None -class MaxProductLoss(nn.Module): +class CrossCosineLoss(nn.Module): def __init__(self, - max_product: FloatLike): # e.g. 20.0 for max_product + max_product: FloatLike): # e.g. 0.2. super().__init__() self.max_product = max_product self.name = None def forward(self, x: Tensor, - y: Tensor, - loss_scale: float) -> Tensor: + delta: Tensor, + loss_scale: float, + mask: Optional[Tensor]) -> Tensor: """ - Compute loss that limits the average dot product (without normalization) - between x, and (y, but randomly permuted on the sequence dimension). It is - intended for limiting dot-products of queries and keys. + Compute loss that limits the average value over the sequence of abs((delta . x) / (x . x)) + x: Tensor of shape (batch_size, seq_len, num_channels) - y: Tensor of shape (batch_size, seq_len2, num_channels) [seq_len2 does not have to equal seq_len]. + delta: Tensor of shape (batch_size, seq_len, num_channels) loss_scale: the scale with which the loss should be incorporated into the graph. This should contain a factor of the grad_scale, if you are using GradScaler for - automatic mixed precision training (amp). We divide this by max_product, - so that it penalizes relative, not absolute, violations of the max-product - rule. - The loss will be summed over frames of x, and multiplied by this value. + automatic mixed precision training (amp). + The loss will be summed over frames of x, i.e. scaled like + batch_size * seq_len * loss_scale * [average excess product] Returns: returns a scaled scalar loss value "ret" which should be incorporated @@ -1262,12 +1272,11 @@ def forward(self, z = with_loss(z, ret, None) where z is any quantity that will be used in calculating the main loss. Ret will always be numerically equal to zero in the forward pass but - may behave as if it were nonzero for backprop purposes. + will behave as if it were nonzero for backprop purposes. """ max_product = float(self.max_product) - return MaxProductLossFunction.apply(x, y, max_product, - loss_scale / max_product, - self.name) + return CrossCosineLossFunction.apply(x, delta, mask, max_product, + loss_scale, self.name) class ChunkCausalDepthwiseConv1d(torch.nn.Module): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 447c909c74..daba4e26ca 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -42,6 +42,12 @@ ScaleLimiter, with_loss, ) +try: + from scaling import CrossCosineLoss +except: + pass + + from torch import Tensor, nn @@ -552,6 +558,8 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) + self.cross_cosine_loss = CrossCosineLoss(max_product=0.2) + self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, @@ -651,6 +659,11 @@ def forward( self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), None) + src = with_loss(src, + self.cross_cosine_loss(src.permute(1, 0, 2), offset.permute(1, 0, 2), + aux_loss_scale, mask=src_key_padding_mask), + None) + src = self.scale_limiter(src, aux_loss_scale) From 9fda83af39422d33667a73fd4c75a335a10584cc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 11 Oct 2025 10:16:12 +0800 Subject: [PATCH 0604/1191] Reduce max_product of CrossCosineLoss from 0.2 to 0.1. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index daba4e26ca..199dbe6c4d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -558,7 +558,7 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.cross_cosine_loss = CrossCosineLoss(max_product=0.2) + self.cross_cosine_loss = CrossCosineLoss(max_product=0.1) self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, From 5b0128abe288a718447cb691637ae552d5c60e7a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 11 Oct 2025 11:50:38 +0800 Subject: [PATCH 0605/1191] Change the thing limited in CrossCosineProduct so the dot product is computed with the half-way between before and after version of x,. --- egs/librispeech/ASR/zipformer/scaling.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 74ff5706a1..4baac3d3b1 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1193,8 +1193,11 @@ def forward(self, # cross cosine loss is for when you have a situation like: # x = x + delta # x = with_loss(x, cross_cosine_loss(x, delta)) -# and we want to make sure that delta only represents, on average, -# a small fraction of the total x. That is, mean(abs((delta . x) / (x . x))) < max_product +# and we want to make sure that adding delta does not change the magnitude +# of x very much, as a proportion of the total magnitude of x. +# we do this by making sure that delta . (x - delta/2) is close to zero, and +# normalize this by dividing by x's squared magnitude (x . x), i.e. +# we ensure that mean(abs((delta . (x - delta/2) / (x . x))) < max_product # you could also probably supply the original x and it would have a somewhat similar effect. class CrossCosineLossFunction(torch.autograd.Function): @staticmethod @@ -1226,7 +1229,7 @@ def backward(ctx, ans_grad): eps = 3.0e-08 # won't be zero in float16 - product = (x * delta).sum(dim=-1).abs() / ((x * x).sum(dim=-1) + eps) + product = ((x - 0.5 * delta) * delta).sum(dim=-1).abs() / ((x * x).sum(dim=-1) + eps) if mask is not None: product = product * (~mask).to(product.dtype) From 7da7dcb15451fd09c2d4853a3fba9b30698cf0f8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 11 Oct 2025 13:48:23 +0800 Subject: [PATCH 0606/1191] Revert ExpNorm to actual ExpNorm, and introduce norm_change_loss with limit=0.2. --- egs/librispeech/ASR/zipformer/scaling.py | 77 +++++++++++----------- egs/librispeech/ASR/zipformer/zipformer.py | 19 +++--- 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 4baac3d3b1..b04d56fe3b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -334,9 +334,8 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): - # plot (0.15 + 0.87*x**0.5)^(1/0.5), x+0.075 for 0 <= x <= 1 x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = ((0.15 + x_norm ** 0.5) ** 2).tanh() + num = (1 - (-x_norm).exp()) scales = num / x_norm scales = scale * scales return (x * scales) @@ -1191,70 +1190,68 @@ def forward(self, # cross cosine loss is for when you have a situation like: -# x = x + delta -# x = with_loss(x, cross_cosine_loss(x, delta)) +# y = y + delta +# y = with_loss(y, cross_cosine_loss(x, y, delta)) # and we want to make sure that adding delta does not change the magnitude -# of x very much, as a proportion of the total magnitude of x. -# we do this by making sure that delta . (x - delta/2) is close to zero, and -# normalize this by dividing by x's squared magnitude (x . x), i.e. -# we ensure that mean(abs((delta . (x - delta/2) / (x . x))) < max_product -# you could also probably supply the original x and it would have a somewhat similar effect. -class CrossCosineLossFunction(torch.autograd.Function): +# of individual embedding vectors very much. +# we do this by making sure that mean(abs(log(|x_i|) - log(|y_i|))) <= limit. +class NormChangeLossFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x: Tensor, delta: Tensor, mask: Optional[Tensor], - max_product: float, weight: float, name: str): - ctx.save_for_backward(x, delta) + def forward(ctx, x: Tensor, y: Tensor, mask: Optional[Tensor], + limit: float, weight: float, name: str): + ctx.save_for_backward(x, y) ctx.name = name ctx.mask = mask # mask will have no grad so it should be OK to store this way ctx.weight = weight - ctx.max_product = max_product + ctx.limit = limit # return fake loss that is always zero but behaves in backprop as if it were a real loss. return torch.tensor(0.0, device=x.device, dtype=x.dtype) @staticmethod @custom_bwd def backward(ctx, ans_grad): - x, delta = ctx.saved_tensors + x, y = ctx.saved_tensors name = ctx.name # str mask = ctx.mask # Tensor or None, shape: (batch_size, seq_len) weight = ctx.weight # float - max_product = ctx.max_product # float + limit = ctx.limit # float (batch_size, seq_len, num_channels) = x.shape with torch.enable_grad(): - x, delta = x.detach(), delta.detach() - x.requires_grad = True - delta.requires_grad = True - - eps = 3.0e-08 # won't be zero in float16 - - product = ((x - 0.5 * delta) * delta).sum(dim=-1).abs() / ((x * x).sum(dim=-1) + eps) + with torch.amp.autocast('cuda', enabled=False): + x, y = x.to(torch.float), y.to(torch.float) + x, y = x.detach(), y.detach() + x.requires_grad = True + y.requires_grad = True + eps = 1.0e-10 + x_sqnorm = (x * x).sum(dim=-1) + eps + y_sqnorm = (y * y).sum(dim=-1) + eps + norm_diff = 0.5 * (x_sqnorm.log() - y_sqnorm.log()).abs() - if mask is not None: - product = product * (~mask).to(product.dtype) + if mask is not None: + norm_diff = norm_diff * (~mask).to(norm_diff.dtype) - product = product.abs() - excess_product = (product.sum(dim=1) - seq_len * max_product).relu() + excess_norm_diff = (norm_diff.sum(dim=1) - seq_len * limit).relu() - if random.random() < 0.001: - logging.info(f"CrossCosineLoss: {name}, limit={max_product}, excess-product={excess_product.mean() / seq_len}") + if random.random() < 0.001: + logging.info(f"NormChangeLoss: {name}, limit={limit}, excess-norm-diff={excess_norm_diff.mean() / seq_len}") - grad = (weight * ans_grad).expand(excess_product.numel()) - excess_product.backward(grad) + grad = (weight * ans_grad).expand(excess_norm_diff.numel()) + excess_norm_diff.backward(grad) - return x.grad, delta.grad, None, None, None, None + return x.grad, y.grad, None, None, None, None -class CrossCosineLoss(nn.Module): +class NormChangeLoss(nn.Module): def __init__(self, - max_product: FloatLike): # e.g. 0.2. + limit: FloatLike): # e.g. 0.2. super().__init__() - self.max_product = max_product + self.limit = limit self.name = None def forward(self, x: Tensor, - delta: Tensor, + y: Tensor, loss_scale: float, mask: Optional[Tensor]) -> Tensor: """ @@ -1262,7 +1259,7 @@ def forward(self, x: Tensor of shape (batch_size, seq_len, num_channels) - delta: Tensor of shape (batch_size, seq_len, num_channels) + y: Tensor of shape (batch_size, seq_len, num_channels) loss_scale: the scale with which the loss should be incorporated into the graph. This should contain a factor of the grad_scale, if you are using GradScaler for automatic mixed precision training (amp). @@ -1277,9 +1274,9 @@ def forward(self, Ret will always be numerically equal to zero in the forward pass but will behave as if it were nonzero for backprop purposes. """ - max_product = float(self.max_product) - return CrossCosineLossFunction.apply(x, delta, mask, max_product, - loss_scale, self.name) + limit = float(self.limit) + return NormChangeLossFunction.apply(x, y, mask, limit, + loss_scale, self.name) class ChunkCausalDepthwiseConv1d(torch.nn.Module): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 199dbe6c4d..e818e94f29 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -43,7 +43,7 @@ with_loss, ) try: - from scaling import CrossCosineLoss + from scaling import NormChangeLoss except: pass @@ -558,7 +558,7 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.cross_cosine_loss = CrossCosineLoss(max_product=0.1) + self.norm_change_loss = NormChangeLoss(limit=0.2) self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, @@ -647,8 +647,17 @@ def forward( residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale + src = src_orig + offset + src = with_loss(src, + self.norm_change_loss(src_orig.permute(1, 0, 2), src.permute(1, 0, 2), + aux_loss_scale, mask=src_key_padding_mask), + None) + + + + src = with_loss(src, self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), None) @@ -659,12 +668,6 @@ def forward( self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), None) - src = with_loss(src, - self.cross_cosine_loss(src.permute(1, 0, 2), offset.permute(1, 0, 2), - aux_loss_scale, mask=src_key_padding_mask), - None) - - src = self.scale_limiter(src, aux_loss_scale) src = self.norm(src) From 4a3873f0a365d5a159b7a2da9d2df398ff235f17 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 11 Oct 2025 16:32:23 +0800 Subject: [PATCH 0607/1191] Increase all the num layers by 1 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 86a95a5b00..cf95b03a25 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="4,6,7,7,7,6", + default="5,7,8,8,8,7", help="Number of zipformer encoder layers per stack, comma separated.", ) From ea9edb9b001579087d033ce324e8114bd0d07751 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Oct 2025 11:48:14 +0800 Subject: [PATCH 0608/1191] Increase min_rms values from 0.02 to 0.05. --- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 248aa0df6e..e727228697 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -247,8 +247,8 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) + self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) - self.scale_limiter = ScaleLimiter(min_rms=0.02, max_rms=2.0) self.out_norm = ExpNorm(out_channels) def forward( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e818e94f29..fd353fc706 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -582,7 +582,7 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(min_rms=0.02, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) self.norm = ExpNorm(embed_dim) From 998f40d4d88bb95e8490896b816d493faaa6edfb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Oct 2025 14:09:00 +0800 Subject: [PATCH 0609/1191] Increase min_rms values from 0.05 to 0.1 but penalize the min_deviation with a third power so that small violations of the limit are ignored. --- egs/librispeech/ASR/zipformer/scaling.py | 10 ++++++---- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b04d56fe3b..5b29ac56c4 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1466,16 +1466,18 @@ def backward(ctx, x_grad: Tensor): x = x.detach() x.requires_grad = True rms = (x ** 2).mean(dim=-1).sqrt() - max_deviation = (rms / ctx.max_rms - 1.).relu() - min_deviation = (1. - rms / ctx.min_rms).relu() + numel = rms.numel() + max_deviation = (rms / ctx.max_rms - 1.).relu().mean() + min_deviation = (1. - rms / ctx.min_rms).relu().mean() if random.random() < 0.002: logging.info( f"ScaleLimiter: name={ctx.name}, min_rms={ctx.min_rms}, max_rms={ctx.max_rms}, " - f"min_deviation={min_deviation.mean()}, max_deviation={max_deviation.mean()}, " + f"min_deviation={min_deviation.item()}, max_deviation={max_deviation.item()}, " f"loss_scale={ctx.aux_loss_scale}" ) - (min_deviation + max_deviation).backward(gradient=torch.full_like(min_deviation, ctx.aux_loss_scale)) + min_deviation = min_deviation ** 3 # strongly de-emphasize small violations of the minimum. + (min_deviation + max_deviation).backward(gradient=torch.full_like(min_deviation, ctx.aux_loss_scale * numel)) return x_grad + x.grad, None, None, None, None diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index e727228697..da8d700634 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -247,7 +247,7 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) - self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) self.out_norm = ExpNorm(out_channels) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fd353fc706..206fdbbb42 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -582,7 +582,7 @@ def __init__( if num_conv_modules >= 1: self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(min_rms=0.05, max_rms=2.0) + self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) self.norm = ExpNorm(embed_dim) From ddc1e93d14260d60daa50bfc4db3530b070f236c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Oct 2025 22:41:41 +0800 Subject: [PATCH 0610/1191] Replace standard ExpNorm formula with ((0.15 + x ** 0.5) ** 2).tanh(). --- egs/librispeech/ASR/zipformer/scaling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5b29ac56c4..4ff012e536 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -334,8 +334,9 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): + # plot (0.15 + 0.87*x**0.5)^(1/0.5), x+0.075 for 0 <= x <= 1 x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (1 - (-x_norm).exp()) + num = ((0.15 + x_norm ** 0.5) ** 2).tanh() scales = num / x_norm scales = scale * scales return (x * scales) From 7c4f415bf2ea8795c175393e95dd898cf222670a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Oct 2025 10:25:59 +0800 Subject: [PATCH 0611/1191] Increase limit in NormChangeLoss from 0.2 to 0.3. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 206fdbbb42..a6abf7f8e4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -558,7 +558,7 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.norm_change_loss = NormChangeLoss(limit=0.2) + self.norm_change_loss = NormChangeLoss(limit=0.3) self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, From 1430b0e02babaa123443349d9c0cfb0acf891bde Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Oct 2025 15:18:25 +0800 Subject: [PATCH 0612/1191] Take the formula tanh(sqrt(0.1**2 + x**2)) from 1338, and the scale min=0.75. --- egs/librispeech/ASR/zipformer/scaling.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 4ff012e536..42943523af 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -334,9 +334,10 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): - # plot (0.15 + 0.87*x**0.5)^(1/0.5), x+0.075 for 0 <= x <= 1 - x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = ((0.15 + x_norm ** 0.5) ** 2).tanh() + eps = 0.1 + var = torch.mean(x ** 2, dim=channel_dim, keepdim=True) + x_norm, x_norm_witheps = var.sqrt(), (var + eps**2).sqrt() + num = x_norm_witheps.tanh() scales = num / x_norm scales = scale * scales return (x * scales) @@ -428,7 +429,7 @@ def forward(self, x: Tensor) -> Tensor: return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.4, max=2.5, training=self.training) + self.scale, min=0.75, max=2.5, training=self.training) ans = ExpNormFunction.apply( x, scale, self.channel_dim, From 659ec65996282b39f282d6d482d96e27477ad7b2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Oct 2025 16:57:11 +0800 Subject: [PATCH 0613/1191] Code cleanup removing no-longer-used conv_module2; also remove self_attn3 and feed_forward3 --- egs/librispeech/ASR/zipformer/zipformer.py | 82 +++++++--------------- 1 file changed, 24 insertions(+), 58 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a6abf7f8e4..4b1d73d421 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -163,7 +163,6 @@ def _to_tuple(x): value_head_dim=value_head_dim[i], feedforward_multiple=feedforward_multiple[i], cnn_module_kernel=cnn_module_kernel[i], - num_conv_modules=1, causal=causal, ) @@ -349,8 +348,8 @@ def streaming_forward( A tensor of shape (batch_size,) containing the number of frames in `x` before padding. states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). + states[i*5:(i+1)*5] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv) src_key_padding_mask: The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. @@ -370,7 +369,7 @@ def streaming_forward( x, new_layer_states = module.streaming_forward( x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + states=states[layer_offset * 6 : (layer_offset + num_layers) * 5], left_context_len=self.left_context_frames[0] // ds, src_key_padding_mask=src_key_padding_mask[..., ::ds], ) @@ -398,7 +397,7 @@ def get_init_states( ) -> List[Tensor]: """Get initial states. - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + A list of cached tensors of all encoder layers. For layer-i, states[i*5:(i+1)*5] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). """ states = [] @@ -546,7 +545,6 @@ def __init__( value_head_dim: int, feedforward_multiple: int, cnn_module_kernel: int = 31, - num_conv_modules: int = 2, causal: bool = False, ) -> None: super(Zipformer2EncoderLayer, self).__init__() @@ -563,24 +561,20 @@ def __init__( self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, - num_heads=3 * num_heads, + num_heads=num_heads, query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, ) - self.self_attn1, self.self_attn2, self.self_attn3 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(3) ] + self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] feedforward_dim = embed_dim * feedforward_multiple self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4) self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim) - self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4) - if num_conv_modules >= 2: - self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - if num_conv_modules >= 1: - self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) @@ -625,7 +619,7 @@ def forward( key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale, ) - attn_weights1, attn_weights2, attn_weights3 = attn_weights.chunk(3, dim=0) + attn_weights1, attn_weights2 = attn_weights.chunk(2, dim=0) src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -633,18 +627,10 @@ def forward( src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - if hasattr(self, 'conv_module1'): - src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.self_attn3(src, attn_weights3, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - - if hasattr(self, 'conv_module2'): - src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) - - src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale @@ -682,8 +668,7 @@ def streaming_forward( cached_nonlin_attn: Tensor, cached_val1: Tensor, cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, + cached_conv: Tensor, left_context_len: int, src_key_padding_mask: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -701,9 +686,7 @@ def streaming_forward( of shape (left_context_len, batch_size, value_dim) cached_val2: cached left context for the second attention module, of shape (left_context_len, batch_size, value_dim) - cached_conv1: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) - cached_conv2: cached left context for the second convolution module, + cached_conv: cached left context for the first convolution module, of shape (batch_size, channels, left_pad) left_context_len: number of left context frames. src_key_padding_mask: the mask for padding, of shape @@ -716,8 +699,7 @@ def streaming_forward( - updated cached_nonlin_attn - updated cached_val1 - updated cached_val2 - - updated cached_conv1 - - updated cached_conv2 + - updated cached_conv """ src_orig = src @@ -749,9 +731,9 @@ def streaming_forward( ) src = src + self_attn - src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src_conv, cached_conv = self.conv_module.streaming_forward( src, - cache=cached_conv1, + cache=cached_conv, src_key_padding_mask=src_key_padding_mask[:, left_context_len:], ) src = src + src_conv @@ -767,18 +749,9 @@ def streaming_forward( ) src = src + self_attn - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm(src) + offset = (src - src_orig) * self.residual_scale - src = self.residual(src_orig, src) + src = src_orig + offset src = self.norm(src) @@ -788,8 +761,7 @@ def streaming_forward( cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, - cached_conv2, + cached_conv, ) @@ -807,8 +779,6 @@ class Zipformer2Encoder(nn.Module): >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) - - """ def __init__( self, @@ -934,8 +904,8 @@ def streaming_forward( Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states: list of cached tensors of N encoder layers. For layer-i, states[i*5:(i+1)*5] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv). left_context_len: Number of left context frames. src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len + seq_len); True means masked position. @@ -958,17 +928,15 @@ def streaming_forward( cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, - cached_conv2, - ) = states[i * 6 : (i + 1) * 6] + cached_conv, + ) = states[i * 5 : (i + 1) * 5] ( src, new_cached_key, new_cached_nonlin_attn, new_cached_val1, new_cached_val2, - new_cached_conv1, - new_cached_conv2, + new_cached_conv, ) = mod.streaming_forward( src, pos_emb, @@ -976,8 +944,7 @@ def streaming_forward( cached_nonlin_attn=cached_nonlin_attn, cached_val1=cached_val1, cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, + cached_conv=cached_conv, left_context_len=left_context_len, src_key_padding_mask=src_key_padding_mask, ) @@ -986,8 +953,7 @@ def streaming_forward( new_cached_nonlin_attn, new_cached_val1, new_cached_val2, - new_cached_conv1, - new_cached_conv2, + new_cached_conv, ] if num_channels > layer_dim: From 95e1171f5b781e2309ef17a8b3e912cab65b86af Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Oct 2025 16:59:17 +0800 Subject: [PATCH 0614/1191] Increase num zipformer layers in stacks by 1, or 2 for central 3 stacks. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index cf95b03a25..ce27ab2698 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="5,7,8,8,8,7", + default="6,8,10,10,10,8", help="Number of zipformer encoder layers per stack, comma separated.", ) From 1a7d53426b03c1bd8bbae20d02328834a71b6bad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 14 Oct 2025 15:42:26 +0800 Subject: [PATCH 0615/1191] Move the SoftNorm to before the residual scaling. # Conflicts: # egs/librispeech/ASR/zipformer/zipformer.py --- egs/librispeech/ASR/zipformer/zipformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4b1d73d421..132865fb54 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -631,6 +631,10 @@ def forward( src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = self.scale_limiter(src, aux_loss_scale) + + src = self.norm(src) + residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale @@ -654,10 +658,6 @@ def forward( self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), None) - src = self.scale_limiter(src, aux_loss_scale) - - src = self.norm(src) - return src def streaming_forward( From 9063881294bd3d81fe054eb847b8694e5e00b947 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Oct 2025 21:24:27 +0800 Subject: [PATCH 0616/1191] Increase num layers by 1 or 2. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ce27ab2698..9afc453572 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,10,10,10,8", + default="6,9,12,12,12,9", help="Number of zipformer encoder layers per stack, comma separated.", ) From 1be175fb3cfb20312c908ac8ee670ceccb679941 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 14 Oct 2025 13:32:06 +0800 Subject: [PATCH 0617/1191] Change formula from tanh(sqrt(x^2 + 0.1^2)), min=0.75 to tanh(x+0.05), min=0.8 --- egs/librispeech/ASR/zipformer/scaling.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 42943523af..5e9069a3dd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,9 +335,8 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): eps = 0.1 - var = torch.mean(x ** 2, dim=channel_dim, keepdim=True) - x_norm, x_norm_witheps = var.sqrt(), (var + eps**2).sqrt() - num = x_norm_witheps.tanh() + x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() + num = (x_norm + 0.05).tanh() scales = num / x_norm scales = scale * scales return (x * scales) @@ -429,7 +428,7 @@ def forward(self, x: Tensor) -> Tensor: return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.75, max=2.5, training=self.training) + self.scale, min=0.8, max=2.5, training=self.training) ans = ExpNormFunction.apply( x, scale, self.channel_dim, From 2b93c58ee9218a286ae68eb8ff916af2c0abfbfc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 14 Oct 2025 22:58:30 +0800 Subject: [PATCH 0618/1191] Change ExpNorm from tanh(x + 0.05), scale min=0.8, to tanh(x + 0.025), scale min=0.9 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5e9069a3dd..a506a50d4d 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -336,7 +336,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): eps = 0.1 x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (x_norm + 0.05).tanh() + num = (x_norm + 0.025).tanh() scales = num / x_norm scales = scale * scales return (x * scales) @@ -428,7 +428,7 @@ def forward(self, x: Tensor) -> Tensor: return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.8, max=2.5, training=self.training) + self.scale, min=0.9, max=2.5, training=self.training) ans = ExpNormFunction.apply( x, scale, self.channel_dim, From 5ad03d9105fd42e8c5bb809afb620f2647ef331f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Oct 2025 10:47:50 +0800 Subject: [PATCH 0619/1191] Take the MaxVarLoss code from 1350, and introduce a MaxVarLoss before the norm module, with schedule decreasing to 0.25 at 30k batches. --- egs/librispeech/ASR/zipformer/scaling.py | 166 ++++++++++++++++----- egs/librispeech/ASR/zipformer/zipformer.py | 30 ++-- 2 files changed, 142 insertions(+), 54 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a506a50d4d..e8434056bd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -334,7 +334,6 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): - eps = 0.1 x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() num = (x_norm + 0.025).tanh() scales = num / x_norm @@ -829,6 +828,95 @@ def forward(self, x: Tensor): return ans +class MaxVarLossFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x: Tensor, mask: Optional[Tensor], max_var: float, weight: float, name: str): + ctx.save_for_backward(x) + if mask is not None: + assert mask.shape == x.shape[:2], (list(mask.shape), list(x.shape)) + ctx.mask = mask # mask will have no grad so it should be OK to store this way + ctx.name = name + ctx.weight = weight + ctx.max_var = max_var + return torch.tensor(0.0, device=x.device, dtype=x.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + x, = ctx.saved_tensors + mask = ctx.mask # optional Tensor + name = ctx.name # str + weight = ctx.weight # float + max_var = ctx.max_var # float + + + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + + eps = 3.0e-08 # won't be zero in float16 + x_var = (x ** 2).mean(dim=-1) + if mask is not None: + mask = (~mask).to(x.dtype) + x_var = x_var * mask + + with torch.amp.autocast('cuda', enabled=False): + x_var = x_var.to(torch.float) + if mask is not None: + numel = mask.sum() + else: + numel = x_var.numel() + excess_var = (x_var.sum() - max_var * numel).relu() + + if random.random() < 0.001: + logging.info(f"MaxVarLoss: {name}, limit={max_var}, excess-var={excess_var.mean() / numel}") + + # scale the loss by less than one, if we are close to the limit. + excess_var = excess_var * (excess_var / (numel * max_var)).clamp(max=1.0) + + # also add a factor of 1. / max_var into the loss scale. + excess_var.backward(gradient=torch.full_like(excess_var, weight * (1. / max_var))) + + return x.grad, None, None, None, None + + +class MaxVarLoss(nn.Module): + def __init__(self, + max_rms: FloatLike): + super().__init__() + self.max_rms = max_rms + self.name = None + + def forward(self, + x: Tensor, + loss_scale: float, + mask: Optional[Tensor] = None) -> Tensor: + """ + Compute loss that acts like a penalty if the mean-square value of x + exceeds self.max_rms**2 + + x: Tensor of shape (batch_size, seq_len, num_channels) + loss_scale: the scale with which the loss should be incorporated into the graph. + This should contain a factor of the grad_scale, if you are using GradScaler for + automatic mixed precision training (amp). + The loss will be summed over frames, and multiplied by this value. + mask: if supplied, mask of shape (batch_size, seq_len); + True means masked positions. + + Returns: + returns a scaled scalar loss value "ret" which should be incorporated + into the backprop graph by doing: + z = with_loss(z, ret, None) + where z is any quantity that will be used in calculating the main loss. + Ret will always be numerically equal to zero in the forward pass but + may behave as if it were nonzero for backprop purposes. + """ + return MaxVarLossFunction.apply(x, mask, + float(self.max_rms) ** 2, + loss_scale, self.name) + + class CosineSimilarityLossFunction(torch.autograd.Function): @staticmethod @custom_fwd @@ -883,6 +971,44 @@ def backward(ctx, ans_grad): return x.grad, None, None, None, None +class CosineSimilarityLoss(nn.Module): + def __init__(self, + max_similarity: FloatLike): # e.g. 0.1 for max_similarity + super().__init__() + self.max_similarity = max_similarity + self.name = None + + def forward(self, + x: Tensor, + loss_scale: float, + mask: Optional[Tensor] = None) -> Tensor: + """ + Compute cosine-similarity loss that tries to make sure distinct output vectors + have inner products with small magnitude (after normalization), i.e. the cosine + of the angle between should be close to zero. + + x: Tensor of shape (batch_size, seq_len, num_channels) + loss_scale: the scale with which the loss should be incorporated into the graph. + This should contain a factor of the grad_scale, if you are using GradScaler for + automatic mixed precision training (amp). + The loss will be summed over frames, and multiplied by this value. + mask: if supplied, mask of shape (batch_size, seq_len); + True means masked positions. + + Returns: + returns a scaled scalar loss value "ret" which should be incorporated + into the backprop graph by doing: + z = with_loss(z, ret, None) + where z is any quantity that will be used in calculating the main loss. + Ret will always be numerically equal to zero in the forward pass but + may behave as if it were nonzero for backprop purposes. + """ + return CosineSimilarityLossFunction.apply(x, mask, + float(self.max_similarity), + loss_scale, self.name) + + + class SimpleOrthogonalPenaltyFunction(torch.autograd.Function): @staticmethod @custom_fwd @@ -1025,42 +1151,6 @@ def get_max_similarity(rank: int, power: float): """ return (0.7978845608 / (rank ** 0.5)) ** power -class CosineSimilarityLoss(nn.Module): - def __init__(self, - max_similarity: FloatLike): # e.g. 0.1 for max_similarity - super().__init__() - self.max_similarity = max_similarity - self.name = None - - def forward(self, - x: Tensor, - loss_scale: float, - mask: Optional[Tensor] = None) -> Tensor: - """ - Compute cosine-similarity loss that tries to make sure distinct output vectors - have inner products with small magnitude (after normalization), i.e. the cosine - of the angle between should be close to zero. - - x: Tensor of shape (batch_size, seq_len, num_channels) - loss_scale: the scale with which the loss should be incorporated into the graph. - This should contain a factor of the grad_scale, if you are using GradScaler for - automatic mixed precision training (amp). - The loss will be summed over frames, and multiplied by this value. - mask: if supplied, mask of shape (batch_size, seq_len); - True means masked positions. - - Returns: - returns a scaled scalar loss value "ret" which should be incorporated - into the backprop graph by doing: - z = with_loss(z, ret, None) - where z is any quantity that will be used in calculating the main loss. - Ret will always be numerically equal to zero in the forward pass but - may behave as if it were nonzero for backprop purposes. - """ - return CosineSimilarityLossFunction.apply(x, mask, - float(self.max_similarity), - loss_scale, self.name) - class MinProductLossFunction(torch.autograd.Function): @staticmethod @@ -1712,7 +1802,7 @@ def backward(ctx, ans_grad: Tensor): ) -def with_loss(x, y, name): +def with_loss(x, y, name=None): # returns x but adds y.sum() to the loss function. return WithLoss.apply(x, y, name) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 132865fb54..2a0dda3650 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -34,6 +34,7 @@ ExpNorm, ChunkCausalDepthwiseConv1d, CosineSimilarityLoss, + ScheduledFloat, FloatLike, convert_num_channels, limit_param_value, @@ -44,6 +45,7 @@ ) try: from scaling import NormChangeLoss + from scaling import MaxVarLoss except: pass @@ -555,8 +557,9 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.norm_change_loss = NormChangeLoss(limit=0.3) + self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 1.2), (30000.0, 0.25), default=1.0)) + self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, @@ -633,6 +636,9 @@ def forward( src = self.scale_limiter(src, aux_loss_scale) + src = with_loss(src, + self.max_var_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + src = self.norm(src) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) @@ -645,18 +651,13 @@ def forward( aux_loss_scale, mask=src_key_padding_mask), None) - - - src = with_loss(src, - self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), - None) + self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) # also put cosine_loss on src, mostly because it will be used in scale_limiter and we don't want the # network to get around the scale limitation by using an offset. src = with_loss(src, - self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), - None) + self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) return src @@ -884,8 +885,7 @@ def forward( self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask) + self.cosine_loss(src.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask), - None) + aux_loss_scale, src_key_padding_mask)) if hasattr(self, 'out_proj'): src = self.out_proj(src) @@ -1273,8 +1273,7 @@ def forward( k = with_loss(k, self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), aux_loss_scale / num_heads, - key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None), - None) + key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None)) # time1 refers to target, time2 refers to source. @@ -1574,7 +1573,7 @@ def forward( if aux_loss_scale: x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, - mask=src_key_padding_mask), None) + mask=src_key_padding_mask)) return x @@ -1662,7 +1661,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int): def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: x = self.in_proj(x) x = self.out_proj(x) - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), None) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask)) return x @@ -1782,8 +1781,7 @@ def forward( x = self.out_proj(x) # (time, batch, channels) - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), - None) + x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask)) return x From 103ba0b0a18fadf721adc2903829c3c795c66313 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Oct 2025 11:16:43 +0800 Subject: [PATCH 0620/1191] Increase initial max_rms to make sure it is not active at the start. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0dda3650..8f3170513b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -558,7 +558,7 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.norm_change_loss = NormChangeLoss(limit=0.3) - self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 1.2), (30000.0, 0.25), default=1.0)) + self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 2.0), (30000.0, 0.25), default=1.0)) self.self_attn_weights = RelPositionMultiheadAttentionWeights( From b6f4a56722525112f654b34f1a285ee590e65c33 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Oct 2025 12:29:56 +0800 Subject: [PATCH 0621/1191] Revert the 1353->1355 change to tanh formula, i.e. revert tanh(x + 0.025),min=0.9 to tanh(x + 0.05),min=0.8 --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e8434056bd..c0ed9c2ac3 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,7 +335,7 @@ def backward(ctx, x_grad, *args): def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (x_norm + 0.025).tanh() + num = (x_norm + 0.05).tanh() scales = num / x_norm scales = scale * scales return (x * scales) @@ -427,7 +427,7 @@ def forward(self, x: Tensor) -> Tensor: return _exp_norm(x, self.scale, self.channel_dim) scale = limit_param_value( - self.scale, min=0.9, max=2.5, training=self.training) + self.scale, min=0.8, max=2.5, training=self.training) ans = ExpNormFunction.apply( x, scale, self.channel_dim, From 94ebfef50757bce6665707cc6aad84cea7eef9ab Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Oct 2025 13:21:57 +0800 Subject: [PATCH 0622/1191] Increase min of final residual scale from 0.1 to 0.25. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8f3170513b..487d0d5cc7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -857,6 +857,7 @@ def forward( residual_scale = limit_param_value(self.residual_scales[0], min=-1.0, max=-0.5) + src_with_bypass = residual_scale * src for i, mod in enumerate(self.layers): @@ -869,7 +870,7 @@ def forward( aux_loss_scale=aux_loss_scale/num_layers, ) residual_scale = limit_param_value(self.residual_scales[i + 1], - min=0.0 if i + 1 < num_layers else 0.1, + min=0.0 if i + 1 < num_layers else 0.25, max=1.0) src_with_bypass = src_with_bypass + residual_scale * src From e115976fd8044b01a4c1237fcf4ed2029935d382 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Oct 2025 16:21:17 +0800 Subject: [PATCH 0623/1191] Have max var loss be applied on the diff, and reduce schedule. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 487d0d5cc7..21788e6408 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -558,7 +558,7 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.norm_change_loss = NormChangeLoss(limit=0.3) - self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 2.0), (30000.0, 0.25), default=1.0)) + self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.15), default=1.0)) self.self_attn_weights = RelPositionMultiheadAttentionWeights( @@ -637,7 +637,7 @@ def forward( src = self.scale_limiter(src, aux_loss_scale) src = with_loss(src, - self.max_var_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + self.max_var_loss((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) src = self.norm(src) From 0ca3f4fc55df97325845e6aa4ab4e06bfde0fc9a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Oct 2025 17:26:55 +0800 Subject: [PATCH 0624/1191] Scale up encoder_pos features by 4. --- egs/librispeech/ASR/zipformer/zipformer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 21788e6408..bb293546b5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -795,8 +795,16 @@ def __init__( self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) self.proj.lr_scale = 0.75 + # scale up the position weights, this is to fix an issue with the + # linear_pos projections otherwise needing to have too-large scale, larger + # than the "default scale" used in AdamW-like + # log-weight decay in TransformedAdam. The issue we are trying + # to solve is that between different runs, the linear_pos projections of + # different self_attn_weights modules get very different scales.. the + # thinking is that sometimes if one of these linear_pos projections has + # a too-small scale, it never "learns something useful". self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, length_factor=1.0 + pos_dim, length_factor=1.0, feat_scale=4, ) self.name = None self.layers = nn.ModuleList( @@ -1030,6 +1038,7 @@ def __init__( embed_dim: int, max_len: int = 1000, length_factor: float = 1.0, + feat_scale: float = 1.0, ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() @@ -1038,6 +1047,7 @@ def __init__( self.pe = None assert length_factor >= 1.0, length_factor self.length_factor = length_factor + self.feat_scale = feat_scale self.extend_pe(torch.tensor(0.0).expand(max_len)) def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: @@ -1088,7 +1098,7 @@ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: pe[:, 1::2] = sines pe[:, -1] = 1.0 # for bias. - self.pe = pe.to(dtype=x.dtype) + self.pe = pe.to(dtype=x.dtype) * self.feat_scale def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: """Create positional encoding. @@ -1209,8 +1219,6 @@ def __init__( self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.5)) - - # linear transformation for positional encoding. self.linear_pos = ScaledLinear( pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 ) From 6985c5cc8a3edfafdf75d11c86647c5e9e65fe6c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Oct 2025 21:01:01 +0800 Subject: [PATCH 0625/1191] Move norm to end of zipformer layer, after residual scale; but keep max_var_loss where it was. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bb293546b5..a2291e76c9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -639,8 +639,6 @@ def forward( src = with_loss(src, self.max_var_loss((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) - src = self.norm(src) - residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale @@ -659,6 +657,8 @@ def forward( src = with_loss(src, self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + src = self.norm(src) + return src def streaming_forward( From feddb238238291061b422250f9e461f0f82b75dc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Oct 2025 11:16:17 +0800 Subject: [PATCH 0626/1191] Tie the position encoding query with the first few dims of the query. --- egs/librispeech/ASR/zipformer/zipformer.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a2291e76c9..dba1afa55e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1200,11 +1200,12 @@ def __init__( self.embed_dim = embed_dim self.num_heads = num_heads self.query_head_dim = query_head_dim + assert pos_head_dim <= query_head_dim self.pos_head_dim = pos_head_dim self.name = None # will be overwritten in training code; for diagnostics. key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + in_proj_dim = (query_head_dim + key_head_dim) * num_heads # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5 that has been used in previous forms of attention, @@ -1262,21 +1263,15 @@ def forward( # self-attention q = x[..., 0:query_dim] k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim, ( - p.shape[-1], - num_heads, - pos_head_dim, - ) q = self.copy_query(q) # for diagnostics only, does nothing. k = self.copy_key(k) - p = self.copy_pos_query(p) + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = q[..., :pos_head_dim] + p = self.copy_pos_query(p) # diagnostics only if aux_loss_scale: k = with_loss(k, From 2b627feec1460f621ce9cca1092fec79a445f3a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Oct 2025 13:24:24 +0800 Subject: [PATCH 0627/1191] Remove min_rms of ScaleLimiter; apply scale_limiter in zipformer layer only to the offset (prior to residual_scale multiplication), with limit 0.25. --- egs/librispeech/ASR/zipformer/scaling.py | 21 +++++++++----------- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++--- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c0ed9c2ac3..2067820ef1 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1540,9 +1540,8 @@ def streaming_forward( class ScaleLimiterFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, min_rms: float, max_rms: float, aux_loss_scale: float, name: str): + def forward(ctx, x: Tensor, max_rms: float, aux_loss_scale: float, name: str): ctx.save_for_backward(x) - ctx.min_rms = min_rms ctx.max_rms = max_rms ctx.aux_loss_scale = aux_loss_scale ctx.name = name @@ -1559,17 +1558,16 @@ def backward(ctx, x_grad: Tensor): rms = (x ** 2).mean(dim=-1).sqrt() numel = rms.numel() - max_deviation = (rms / ctx.max_rms - 1.).relu().mean() - min_deviation = (1. - rms / ctx.min_rms).relu().mean() + excess = (rms / ctx.max_rms - 1.).relu().mean() + if random.random() < 0.002: logging.info( - f"ScaleLimiter: name={ctx.name}, min_rms={ctx.min_rms}, max_rms={ctx.max_rms}, " - f"min_deviation={min_deviation.item()}, max_deviation={max_deviation.item()}, " + f"ScaleLimiter: name={ctx.name}, max_rms={ctx.max_rms}, " + f"excess={excess.item()}, " f"loss_scale={ctx.aux_loss_scale}" ) - min_deviation = min_deviation ** 3 # strongly de-emphasize small violations of the minimum. - (min_deviation + max_deviation).backward(gradient=torch.full_like(min_deviation, ctx.aux_loss_scale * numel)) - return x_grad + x.grad, None, None, None, None + excess.backward(gradient=torch.full_like(excess, ctx.aux_loss_scale * numel)) + return x_grad + x.grad, None, None, None class ScaleLimiter(torch.nn.Module): @@ -1579,10 +1577,9 @@ class ScaleLimiter(torch.nn.Module): Assumes channel dim is -1 and the input shape has >1 dimension. """ - def __init__(self, min_rms: FloatLike, max_rms: FloatLike): + def __init__(self, max_rms: FloatLike): super().__init__() self.name = None - self.min_rms = min_rms self.max_rms = max_rms @@ -1590,7 +1587,7 @@ def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return _no_op(x) else: - return ScaleLimiterFunction.apply(x, float(self.min_rms), float(self.max_rms), + return ScaleLimiterFunction.apply(x, float(self.max_rms), aux_loss_scale, self.name) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index da8d700634..0b95ce6856 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -247,7 +247,7 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) - self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) + self.scale_limiter = ScaleLimiter(max_rms=2.0) self.out_norm = ExpNorm(out_channels) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index dba1afa55e..8f36969194 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -579,7 +579,7 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.scale_limiter = ScaleLimiter(min_rms=0.1, max_rms=2.0) + self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) self.norm = ExpNorm(embed_dim) @@ -634,14 +634,14 @@ def forward( src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = self.scale_limiter(src, aux_loss_scale) - src = with_loss(src, self.max_var_loss((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale + offset = self.offset_scale_limiter(offset, 0.1 * aux_loss_scale) + src = src_orig + offset src = with_loss(src, From 36b2368106e10df81d2071b9f4bfc22ebb95e6d6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Oct 2025 14:16:13 +0800 Subject: [PATCH 0628/1191] Add 0.02 * torch.randn_like(x) to the features in training. --- egs/librispeech/ASR/zapformer/model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index ed8cbc2cf8..bcc1b4fedc 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -150,17 +150,15 @@ def forward_encoder( encoder_out_lens: Encoder output lengths, of shape (N,). """ - # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") - specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) + + if self.training: + noise_scale = 1.0e-02 + x = x + noise_scale * torch.rand_like(x) x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - src_key_padding_mask = make_pad_mask(x_lens) # (N, T) - specaug_mask = specaug_mask[:, ::2] - assert abs(specaug_mask.shape[1] - src_key_padding_mask.shape[1]) < 10 - specaug_mask = convert_num_channels(specaug_mask, src_key_padding_mask.shape[1]) # pad or truncate. (N, T) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) From c69d7912bfe9828ccc9e5989e35e046dc414fd0e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Oct 2025 15:39:27 +0800 Subject: [PATCH 0629/1191] Reduce scale_default from 0.1 to 0.05 but make it 4.0 for linear_pos modules to bias their scale upwards. Remove feat_scale of 4 on CompactRelPositionEncoding. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 9 +++++---- icefall/utils.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 376d4a5aa2..3200be7092 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -377,7 +377,7 @@ def __init__( direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, - scale_default=0.2, + scale_default=0.05, scalar_lr_scale=0.1, scaling_lr_scale=0.1, eps=1.0e-08, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8f36969194..7cb5ce129c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -804,7 +804,7 @@ def __init__( # thinking is that sometimes if one of these linear_pos projections has # a too-small scale, it never "learns something useful". self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, length_factor=1.0, feat_scale=4, + pos_dim, length_factor=1.0, ) self.name = None self.layers = nn.ModuleList( @@ -1038,7 +1038,6 @@ def __init__( embed_dim: int, max_len: int = 1000, length_factor: float = 1.0, - feat_scale: float = 1.0, ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() @@ -1047,7 +1046,6 @@ def __init__( self.pe = None assert length_factor >= 1.0, length_factor self.length_factor = length_factor - self.feat_scale = feat_scale self.extend_pe(torch.tensor(0.0).expand(max_len)) def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: @@ -1098,7 +1096,7 @@ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: pe[:, 1::2] = sines pe[:, -1] = 1.0 # for bias. - self.pe = pe.to(dtype=x.dtype) * self.feat_scale + self.pe = pe.to(dtype=x.dtype) def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: """Create positional encoding. @@ -1223,6 +1221,9 @@ def __init__( self.linear_pos = ScaledLinear( pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 ) + # the next line will give an upward bias to the scale of linear_pos, nudging it towards + # a state where it gives a meaningful contribution to the stores. + self.linear_pos.scale_default = 10.0 # the following are for diagnostics only, see --print-diagnostics option self.copy_pos_query = Identity() diff --git a/icefall/utils.py b/icefall/utils.py index e69ab8cd05..e523a2e546 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1474,7 +1474,7 @@ def get_parameter_groups_with_lrs( lr: float, include_names: bool = False, freeze_modules: List[str] = [], - attrs: List[str] = ['lr_scale', 'weight_min_rms', 'bias_min_rms', 'weight_max_rms', 'bias_max_rms'], + attrs: List[str] = ['lr_scale', 'weight_min_rms', 'bias_min_rms', 'weight_max_rms', 'bias_max_rms', 'scale_default'], ) -> List[dict]: """ This is to automatically create parameter-groups with overrides of parameter optimizer From f221d164c331f1044f7fdb98561e6e85c2e8bd7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Oct 2025 17:06:50 +0800 Subject: [PATCH 0630/1191] Increase max_rms of offset scale limiter from 0.25 to 0.2, halve scale on aux_loss_scale given to it to 0.05 --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8f36969194..c032d422d1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -579,7 +579,7 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) + self.offset_scale_limiter = ScaleLimiter(max_rms=0.2) self.norm = ExpNorm(embed_dim) @@ -640,7 +640,7 @@ def forward( residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale - offset = self.offset_scale_limiter(offset, 0.1 * aux_loss_scale) + offset = self.offset_scale_limiter(offset, 0.05 * aux_loss_scale) src = src_orig + offset From f28d32046aa405ae7c69848af4e39857d5481907 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Oct 2025 17:52:34 +0800 Subject: [PATCH 0631/1191] Move adding noise to inside Conv2dSubsampling and remove it from model.py --- egs/librispeech/ASR/zapformer/model.py | 5 --- egs/librispeech/ASR/zipformer/subsampling.py | 47 ++++++++------------ 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index bcc1b4fedc..a6aced14d9 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -150,11 +150,6 @@ def forward_encoder( encoder_out_lens: Encoder output lengths, of shape (N,). """ - - if self.training: - noise_scale = 1.0e-02 - x = x + noise_scale * torch.rand_like(x) - x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 0b95ce6856..44bc01c0c9 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -25,6 +25,7 @@ ScaledLinear, ExpNorm, FloatLike, + get_max_similarity, ScaledConv2d, ScaleGrad, ScheduledFloat, @@ -35,32 +36,20 @@ ) from torch import Tensor, nn -# TEMP: put this here, eventually we should import from scaling.py -def get_max_similarity(rank: int, power: float): - """ - For use when initializing CosineSimilarityLoss, this returns a value for - the "max_similarity" argument. - max_similarity is an upper limit we impose on the mean value of (x_i . x_j), - where i != j are two different sequence-position indexes and x_i and x_j are - activation vectors normalized to have unit length. - - rank: the dimension of the space, usually this is the num_channels, but if - we have just up-projected from a bottleneck, it would be the bottleneck - dimension. - power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean - we enforce the vector dimensions to be completely independent like Gaussian noise - (don't do this); if we set power=0.0 it would be equivalent to not having - the CosineSimilarityLoss at all. - - The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal - variable. If x consists of independent Gaussian noise of dimension D, with - variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" - would be close to a no-op for large D), then (x_i . x_j) would be distributed as - a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) - would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value - between this and 1, as a kind of heuristic limit on this max_similarity. - """ - return (0.7978845608 / (rank ** 0.5)) ** power + +class AddNoise(nn.Module): + # assume Conv2d-style input: (N, C, H, W) + def __init__(self, rel_noise_scale: float): + super().__init__() + self.rel_noise_scale = rel_noise_scale + + def forward(self, x: Tensor) -> Tensor: + if not self.training: + return x + eps = 3.0e-08 + noise_scale = ((x ** 2).mean(dim=(1,2,3), keepdim=True) + eps).sqrt() * self.rel_noise_scale + return x + noise_scale * torch.randn_like(x) + class ConvNeXt(nn.Module): @@ -178,8 +167,8 @@ def __init__( self, in_channels: int, out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, + layer1_channels: int = 16, + layer2_channels: int = 64, layer3_channels: int = 128, ) -> None: """ @@ -213,7 +202,7 @@ def __init__( kernel_size=3, padding=(0, 1), # (time, freq) ), - ScaleGrad(0.2), + AddNoise(rel_noise_scale=5.0e-03), SwashR(), nn.Conv2d( in_channels=layer1_channels, From 972ea3dc1740fdce9305da04fefd54bc64c63de2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Oct 2025 16:51:44 +0800 Subject: [PATCH 0632/1191] Remove norm change loss. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 95ff2d4985..dd38cf2874 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -557,7 +557,6 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.norm_change_loss = NormChangeLoss(limit=0.3) self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.15), default=1.0)) @@ -644,11 +643,6 @@ def forward( src = src_orig + offset - src = with_loss(src, - self.norm_change_loss(src_orig.permute(1, 0, 2), src.permute(1, 0, 2), - aux_loss_scale, mask=src_key_padding_mask), - None) - src = with_loss(src, self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) From a8fb5091f6466419350620158413b09bf3e04e15 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Oct 2025 11:18:41 +0800 Subject: [PATCH 0633/1191] Increase max_rms of mx_var_loss from 0.2 to 0.25 --- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index dd38cf2874..dc4ebb7544 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -558,6 +558,7 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.15), default=1.0)) + self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) self.self_attn_weights = RelPositionMultiheadAttentionWeights( @@ -578,8 +579,6 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - self.offset_scale_limiter = ScaleLimiter(max_rms=0.2) - self.norm = ExpNorm(embed_dim) From b00d42a9e37a121afb0383d4939b0c9c7ed7938f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Oct 2025 12:56:35 +0800 Subject: [PATCH 0634/1191] Add a second max_var_loss, on the scaled offset, with limit 0.1, and increase the limit on the unscaled offset from 0.15 to 0.2. --- egs/librispeech/ASR/zipformer/zipformer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index dc4ebb7544..0aa00c078e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -557,7 +557,8 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.15), default=1.0)) + self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.15), default=1.0)) + self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) @@ -633,11 +634,14 @@ def forward( src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = with_loss(src, - self.max_var_loss((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + self.max_var_loss1((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale + offset = with_loss(offset, + self.max_var_loss2(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + offset = self.offset_scale_limiter(offset, 0.05 * aux_loss_scale) src = src_orig + offset From ce3a3271a563bddd29491774e1ed3ae0e10d7435 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Oct 2025 13:03:42 +0800 Subject: [PATCH 0635/1191] Meant to include this last commit --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0aa00c078e..d392f2d044 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -557,7 +557,7 @@ def __init__( self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.15), default=1.0)) + self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) From c26b142efa7dc2b3fa3255939f939e0620ed943d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 16 Oct 2025 22:46:15 +0800 Subject: [PATCH 0636/1191] Fix issue RE log_scale_default->scale_default so config is actually specified. --- egs/librispeech/ASR/zipformer/optim.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3200be7092..6c01fbdff1 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -220,7 +220,7 @@ def reverse_transform_param(group, p, orig_shape): # Apply weight-decay of log_scale, similar to weight decay of AdamW, except it regresses the # log-scale to a default value instead of regressing the scale towards zero. - log_scale_default = group["log_scale_default"] + log_scale_default = math.log(group["scale_default"]) log_scale = ((log_scale - log_scale_default) * (1. - group["scale_decay"] * scaling_lr)) + log_scale_default scale = log_scale.exp().clamp(min=min_scale, max=max_scale) @@ -396,7 +396,7 @@ def __init__( direct=direct, beta2=beta2, scale_decay=scale_decay, - log_scale_default=math.log(scale_default), + scale_default=scale_default, scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, @@ -878,7 +878,7 @@ def __init__( direct=direct, beta2=beta2, scale_decay=scale_decay, - log_scale_default=math.log(scale_default), + scale_default=scale_default, scalar_lr_scale=scalar_lr_scale, scaling_lr_scale=scaling_lr_scale, eps=eps, @@ -1529,7 +1529,7 @@ def _test_transformed_adam(hidden_dim: int): def _test_transform_params(): # caution: this has occasional errors. group = { "bias_min_scale": 0.001, "weight_min_scale": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5, - "log_scale_default": 0.05, "scale_decay": 0.01, + "scale_default": 0.05, "scale_decay": 0.01, "weight_max_scale": 20.0, "bias_max_scale": 20.0, "lr": 0.0} # lr set to 0.0 so weight-scale decay does not happen. for scale in [ 0.0, 1.0e-05, 0.001, 0.01, 1.0, 10.0 ]: for shape in [ (1, 1), (2, 1), (2, 2), (2, 3, 4), (3, 10, 20), (4,) ]: From c2ee8043767a89fcd8b5ca77b88e770e4657a145 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 19 Oct 2025 12:44:26 +0800 Subject: [PATCH 0637/1191] Remove scale_default override of linear_pos --- egs/librispeech/ASR/zipformer/zipformer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d392f2d044..2a93b23739 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1218,9 +1218,6 @@ def __init__( self.linear_pos = ScaledLinear( pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 ) - # the next line will give an upward bias to the scale of linear_pos, nudging it towards - # a state where it gives a meaningful contribution to the stores. - self.linear_pos.scale_default = 10.0 # the following are for diagnostics only, see --print-diagnostics option self.copy_pos_query = Identity() From a8223d75bd303beb9d445a506e835d6658a7c29b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 19 Oct 2025 13:07:16 +0800 Subject: [PATCH 0638/1191] Simple cleanup: remove unused Zipformer2Encoder out_proj --- egs/librispeech/ASR/zipformer/zipformer.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a93b23739..3e0ea7a401 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -175,7 +175,6 @@ def _to_tuple(x): num_encoder_layers[i], dim=downsampling_factor[i]*input_dim, pos_dim=pos_dim, - out_proj=False, # (downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], ) encoders.append(encoder) @@ -784,7 +783,6 @@ def __init__( num_layers: int, dim: int, pos_dim: int, - out_proj: bool, ) -> None: super().__init__() @@ -819,11 +817,6 @@ def __init__( self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) - # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear - # module. - if out_proj: - self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) - self.out_proj.lr_scale = 0.75 def forward( @@ -893,9 +886,6 @@ def forward( self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask)) - if hasattr(self, 'out_proj'): - src = self.out_proj(src) - return src From 220fe7e9238e4ea0b6cb94436c6a25392a73a673 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 19 Oct 2025 13:20:30 +0800 Subject: [PATCH 0639/1191] Remove reconstruction loss from the code. --- egs/librispeech/ASR/zapformer/model.py | 90 +------------------------- egs/librispeech/ASR/zapformer/train.py | 18 +----- 2 files changed, 4 insertions(+), 104 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index a6aced14d9..f6735a65c4 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -123,9 +123,6 @@ def __init__( else: assert attention_decoder is None - self.reconstruction_proj = ScaledLinear( - encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) - def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, @@ -524,89 +521,4 @@ def forward( else: attention_decoder_loss = torch.empty(0) - reconstruction_loss = self.forward_reconstruction_loss(self.gauss_norm(x_no_specaug, x_lens), - encoder_out, encoder_out_lens) - - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss - - - - def gauss_norm(self, - log_mels: Tensor, - log_mel_lens: Tensor) -> Tensor: - (batch_size, seq_len, num_channels) = log_mels.shape - - rand_pos = torch.randint(100000000, (batch_size, seq_len, num_channels), device=log_mels.device) - rand_pos = rand_pos % log_mel_lens.unsqueeze(-1).unsqueeze(-1) - arange = torch.arange(seq_len, device=log_mels.device)[None, :, None].expand_as(rand_pos) - length_mask = make_pad_mask(log_mel_lens) # True in masked positions - - # select the "self" position if we are in the non-masked region; select random - # non-masked positions when in padding regions. - length_mask = length_mask.unsqueeze(-1).expand_as(log_mels) - log_mels = torch.gather(log_mels, dim=1, index=torch.where(length_mask, rand_pos, arange)) - - values, indexes = log_mels.sort(dim=1) # sort on seq dim - N = max(2, log_mels.shape[1]) - norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, log_mels.shape[1], device=log_mels.device, dtype=torch.float) - norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data - norm_rank = norm_rank.reshape(1, -1, 1) - norm_rank = norm_rank.repeat(log_mels.shape[0], 1, log_mels.shape[2]) - log_mels_norm = torch.empty_like(log_mels) - log_mels_norm.scatter_(dim=1, index=indexes, src=norm_rank) - return log_mels_norm - - - def forward_reconstruction_loss(self, - log_mels: Tensor, - encoder_out: Tensor, - encoder_out_lens: Tensor): - """ - Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If - use_cr_ctc then we swap the first and second halves of the batch. - - Args: - log_mels: log-mel features of shape (batch_size, T, num_mels) - encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) - """ - batch_size = log_mels.shape[0] - num_mels = log_mels.shape[2] - - pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) - T_embed = pred_mels.shape[1] - pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) - - excess_frames = log_mels.shape[1] - pred_mels.shape[1] - assert 4 < excess_frames < 10 # should be around 7 or 8 I believe. - - T = pred_mels.shape[1] - offset = 3 # i found excess_frames = 5 one time. - log_mels = log_mels[:, offset:offset+T] - - lens = encoder_out_lens * 4 - pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions - assert pad_mask.shape == (batch_size, T) - pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position - # padd_mask: (batch_size, T, 1) - - - # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly - # helps to down-weight the effect of very silent silences. - #loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, - # reduction='none', beta=1.0) - # this way of applying the padding mask is not really ideal in terms of normalization, - # it will cause us to under-normalize a bit. - diff = (log_mels - pred_mels) * pad_mask - - loss = (diff ** 2) - - # removing the masking logic since we now use the no-specaug reference sequence. - ## masking. if it's different from the next item on both the frequency dim - ## and the time dim, it means we are in neither a time masked nor a frequency masked - ## position. - #mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), - # log_mels != torch.roll(log_mels, 1, dims=1)) - #loss = loss * mask.to(loss.dtype) - - loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. - return loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 9afc453572..7f73b64115 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -536,13 +536,6 @@ def get_parser(): help="Scale for consistency-regularization loss.", ) - parser.add_argument( - "--reconstruction-loss-scale", - type=float, - default=0.005, - help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", - ) - parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -666,8 +659,8 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - warm_step: The warmup period that dictates the decay of the - scale on pruned loss (for transducer) and the reconstruction and prediction - losses. Expressed in terms of the "adjusted batch count", i.e. the + scale on pruned loss (for transducer). + Expressed in terms of the "adjusted batch count", i.e. the normalized batch count after adjusting for changes in batch size. """ params = AttributeDict( @@ -988,7 +981,7 @@ def compute_loss( spec_augment = None # disable spec-aug with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -1022,10 +1015,6 @@ def warmup_schedule(scale, initial_factor): if num_copies > 1: loss += params.cr_loss_scale * cr_loss - reconstruction_loss_scale = params.reconstruction_loss_scale - - loss += reconstruction_loss_scale * reconstruction_loss - if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -1048,7 +1037,6 @@ def warmup_schedule(scale, initial_factor): info["ctc_loss"] = ctc_loss.detach().cpu().item() if num_copies > 1: info["cr_loss"] = cr_loss.detach().cpu().item() - info["recon_loss"] = reconstruction_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() From e3aa4405d23f8bb1e7a1afc1ddef503e76feb7bc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 20 Oct 2025 12:51:24 +0800 Subject: [PATCH 0640/1191] Increase attn_scores_limit from 8.0 to 12.0. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3e0ea7a401..e516ef27c5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1338,7 +1338,7 @@ def forward( if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. + attn_scores_limit = 12.0 # limit on our metric that affects how much grad we are likely to backpropagate. attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, aux_loss_scale, self.name) From ff43f485890707926a546d023aec8ea370c08250 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 10:30:57 +0800 Subject: [PATCH 0641/1191] Revert attn_scores_limit to 8.0; decrease its aux_loss_scale by factor of 10; apply mask before penalizing. --- egs/librispeech/ASR/zipformer/zipformer.py | 26 ++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e516ef27c5..33ac45e322 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1117,9 +1117,11 @@ def forward( attn_scores: Tensor, limit: float, aux_loss_scale: float, + key_padding_mask: Optional[Tensor], name: str): # attn_scores: (head, batch, query_time, key_time) ctx.save_for_backward(attn_scores) + ctx.mask = key_padding_mask # has no grad ctx.limit = limit ctx.aux_loss_scale = aux_loss_scale ctx.name = name @@ -1130,26 +1132,31 @@ def backward( ctx, attn_scores_grad): attn_scores, = ctx.saved_tensors + mask = ctx.mask (num_heads, batch_size, seq_len, _) = attn_scores.shape with torch.amp.autocast('cuda', enabled=False): attn_scores = attn_scores.to(torch.float) attn_scores = attn_scores.detach() + # attn_scores: (head, batch, query_time, key_time) attn_scores.requires_grad = True with torch.enable_grad(): probs = attn_scores.softmax(dim=-1) - # attn_scores: (head, batch, query_time, key_time) scaled_scores = attn_scores.abs() * probs - query_scores = (scaled_scores.sum(dim=-1) - ctx.limit).relu() - - if random.random() < 0.001: - query_excess = query_scores.mean(dim=(1,2)) - logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, query_excess={query_excess}") + avg_scores = scaled_scores.sum(dim=-1) # (head, batch, query_time) + if mask is not None: + avg_scores = avg_scores * (~mask) # mask: (batch, time) + query_scores = (avg_scores - ctx.limit).relu() + + if random.random() < 0.0005: + query_excess = query_scores.mean(dim=(1,2)).to('cpu') + avg_scores_mean = avg_scores.mean(dim=(1,2)).to('cpu') + logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, avg_scores={avg_scores_mean}, query_excess={query_excess}") # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number # of frames which is batch_size * seq_len. normalize by dividing by num heads. # also divide by ctx.limit so it's like penalizing a relative excess. query_scores.backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / (num_heads * ctx.limit))) - return attn_scores_grad + attn_scores.grad, None, None, None + return attn_scores_grad + attn_scores.grad, None, None, None, None @@ -1338,8 +1345,9 @@ def forward( if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - attn_scores_limit = 12.0 # limit on our metric that affects how much grad we are likely to backpropagate. - attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, aux_loss_scale, self.name) + attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, 0.1 * aux_loss_scale, + key_padding_mask, self.name) # We use our own version of softmax, defined in scaling.py, which should From ef0e92a51cbca7b5750fe53590e7b4a8046e28a2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 14:22:47 +0800 Subject: [PATCH 0642/1191] Implement CorrelationLimiter, use it with limit 0.25*0.1*0.1. --- egs/librispeech/ASR/zipformer/scaling.py | 102 +++++++++++++++++++++ egs/librispeech/ASR/zipformer/zipformer.py | 6 ++ 2 files changed, 108 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2067820ef1..7769ffb769 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1591,6 +1591,108 @@ def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: aux_loss_scale, self.name) +class CorrelationLimiterFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, limit: float, aux_loss_scale: float, mask: Optional[Tensor], name: str): + ctx.save_for_backward(x, y) + ctx.limit = limit + ctx.mask = mask + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 + x, y = ctx.saved_tensors + dim = x.shape[-1] + mask = ctx.mask + limit = ctx.limit + aux_loss_scale = ctx.aux_loss_scale + with torch.enable_grad(): + with torch.amp.autocast('cuda', enabled=False): + if mask: + mask = (~mask).to(x.dtype).unsqueeze(-1) + x = x * mask + x, y = x.to(torch.float), y.to(torch.float) + x, y = x.detach(), y.detach() + + X, Y = x.reshape(-1, dim), y.reshape(-1, dim) + + X.requires_grad = True + Y.requires_grad = True + N = X.shape[0] + M = 32 # number of random vectors, this should be more than enough. + r = torch.randn(M, dim) # (M, dim) + r = torch.matmul(r, X.t()) # (M, N) + r = torch.matmul(r, Y) # (M, dim) + r = r * (1. / N) + r = torch.matmul(r, Y.t()) # (M, N) + r = torch.matmul(r, X) # (M, dim) + r = r * (1. / N) + + metric = (r ** 2).mean() + # now, with reference to the comment for class CorrelationLimiter, + # metric should equal an estimate of tr(M^T M M^T M) / dim. + + metric = metric ** (1/4) + # now we have a metric that's proportional in scale to the eigenvalues of + # M, where M is E[x y^T] + + loss = (metric - limit).relu() + + if random.random() < 0.001: + logging.info( + f"CorrelationLimiter: name={ctx.name}, limit={limit}, " + f"metric={metric.item()}, loss={loss.item()}, " + f"loss_scale={aux_loss_scale}" + ) + + loss.backward(torch.full_like(loss, aux_loss_scale)) + + return x.grad, y.grad, None, None, None, None + + +class CorrelationLimiter(torch.nn.Module): + """ + Adds a penalty in backprop if feature x and feature y are too correlated, + based on a randomized algorithm. The correlation limit is specified based + on a limit on tr(M^T M M^T M) / dim, where: + M = E [x y^t], + and this is the same as the mean of the [singular values of M taken to the + fourth power.] We can measure this, concretely, as: + E[ ||M^T M n||_2^2 ] + where n is Gaussian noise, we do this for several vectors x. We can + compute M^T M n as mean[ x_i y_i^T y_i x_i^T n ] = (X^T (Y (Y^T (X n)))). + + Now, the eigenvalues of M should be related to the [magnitude of x] * [magnitude of y] + * [some factor that expresses their correlation.], and the magnitudes of x and y + are assumed to be fixed by the SoftNorm. So the user can express a limit in terms + of a [expected magnitude of x] * [expected magnitude of y] * [max correlation], + and we can take E[ ||M^T M n||_2^2 ] to the power 1/4 before comparing to the limit. + + Assumes channel dim is -1 and the input shape has >1 dimension. + """ + def __init__(self, limit: FloatLike): + # dimensionally, limit is [expected magnitude of x] * [expected magnitude of y] * [max correlation coefficient] + super().__init__() + self.name = None + self.limit = limit + + + def forward(self, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: + # returns a scalar tensor that should be included in the loss function with: + # z = with_loss(z, ret, None) + # where z is any quantity that will be used in calculating the main loss. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return torch.tensor(0.0, device=x.device) + else: + return CorrelationLimiterFunction.apply(x, y, float(self.limit), + aux_loss_scale, mask, + self.name) + + + + def penalize_abs_values_gt( x: Tensor, limit: float, penalty: float, name: str = None ) -> Tensor: diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 33ac45e322..9ba6f12544 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -46,6 +46,7 @@ try: from scaling import NormChangeLoss from scaling import MaxVarLoss + from scaling import CorrelationLimiter except: pass @@ -559,6 +560,7 @@ def __init__( self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) + self.offset_correlation_limiter = CorrelationLimiter(limit=0.25 * 0.1 * 0.1) self.self_attn_weights = RelPositionMultiheadAttentionWeights( @@ -643,6 +645,10 @@ def forward( offset = self.offset_scale_limiter(offset, 0.05 * aux_loss_scale) + offset = with_loss(offset, + self.offset_correlation_limiter(offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), + aux_loss_scale, mask=src_key_padding_mask)) + src = src_orig + offset src = with_loss(src, From 6f9c481c36096435d980a33843489f091a2e1f3f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 14:30:08 +0800 Subject: [PATCH 0643/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7769ffb769..6db4b65933 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1610,7 +1610,7 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 aux_loss_scale = ctx.aux_loss_scale with torch.enable_grad(): with torch.amp.autocast('cuda', enabled=False): - if mask: + if mask is not None: mask = (~mask).to(x.dtype).unsqueeze(-1) x = x * mask x, y = x.to(torch.float), y.to(torch.float) From 7abd5e170d5462e2ec11d3c85805504e7b4c4466 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 14:32:57 +0800 Subject: [PATCH 0644/1191] Bug fix regarding device --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6db4b65933..6062662369 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1622,7 +1622,7 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 Y.requires_grad = True N = X.shape[0] M = 32 # number of random vectors, this should be more than enough. - r = torch.randn(M, dim) # (M, dim) + r = torch.randn(M, dim, device=x.device) # (M, dim) r = torch.matmul(r, X.t()) # (M, N) r = torch.matmul(r, Y) # (M, dim) r = r * (1. / N) From 15e001413542fc4ca1e1bfea5e184b4a54d010f1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 14:49:19 +0800 Subject: [PATCH 0645/1191] multiply aux_loss_scale by N. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6062662369..f724230125 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1647,7 +1647,7 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 f"loss_scale={aux_loss_scale}" ) - loss.backward(torch.full_like(loss, aux_loss_scale)) + loss.backward(torch.full_like(loss, aux_loss_scale * N)) return x.grad, y.grad, None, None, None, None From a8119f0d6bf1a998b773683c4e9ade3746647898 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 14:52:29 +0800 Subject: [PATCH 0646/1191] Cosmetic update --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index f724230125..c46db9ec94 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1647,7 +1647,7 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 f"loss_scale={aux_loss_scale}" ) - loss.backward(torch.full_like(loss, aux_loss_scale * N)) + loss.backward(gradient=torch.full_like(loss, aux_loss_scale * N)) return x.grad, y.grad, None, None, None, None From e16f0d440b3fe480d79bdf23273a9ed929387a2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 14:57:39 +0800 Subject: [PATCH 0647/1191] Bug fix to have gradients be actually be computed. --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c46db9ec94..865aff2706 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1615,11 +1615,11 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 x = x * mask x, y = x.to(torch.float), y.to(torch.float) x, y = x.detach(), y.detach() + x.requires_grad = True + y.requires_grad = True X, Y = x.reshape(-1, dim), y.reshape(-1, dim) - X.requires_grad = True - Y.requires_grad = True N = X.shape[0] M = 32 # number of random vectors, this should be more than enough. r = torch.randn(M, dim, device=x.device) # (M, dim) From 366c39cb94765c53533bf8ad12ed584a3ffd2430 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 15:23:12 +0800 Subject: [PATCH 0648/1191] Improve printout of ScaleLimiter --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 865aff2706..e6ac701125 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1563,7 +1563,7 @@ def backward(ctx, x_grad: Tensor): if random.random() < 0.002: logging.info( f"ScaleLimiter: name={ctx.name}, max_rms={ctx.max_rms}, " - f"excess={excess.item()}, " + f"rms={rms.mean().item()}, excess={excess.item()}, " f"loss_scale={ctx.aux_loss_scale}" ) excess.backward(gradient=torch.full_like(excess, ctx.aux_loss_scale * numel)) From 188302f77ee0d50b6b680203d3482af00ecbbb91 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 16:29:08 +0800 Subject: [PATCH 0649/1191] Make the limit of correlation limiter 2.5 times larger and the loss scale 10 times smaller. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9ba6f12544..9cef6b9f5f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -560,7 +560,7 @@ def __init__( self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) - self.offset_correlation_limiter = CorrelationLimiter(limit=0.25 * 0.1 * 0.1) + self.offset_correlation_limiter = CorrelationLimiter(limit=0.25 * 0.1 * 0.25) self.self_attn_weights = RelPositionMultiheadAttentionWeights( @@ -647,7 +647,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter(offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), - aux_loss_scale, mask=src_key_padding_mask)) + 0.1 * aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From ccfa0643a6d817a52d30ca20b2f1e03bf695beb5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 16:57:58 +0800 Subject: [PATCH 0650/1191] Remove max_var_loss{1,2}, cosine_loss and offset_cosine_loss from Zipformer2EncoderLayer. --- egs/librispeech/ASR/zipformer/zipformer.py | 25 +++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9cef6b9f5f..9c156f1b75 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -555,10 +555,10 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) - self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) + #self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) + #self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) + #self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) + #self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) self.offset_correlation_limiter = CorrelationLimiter(limit=0.25 * 0.1 * 0.25) @@ -634,14 +634,14 @@ def forward( src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = with_loss(src, - self.max_var_loss1((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + #src = with_loss(src, + # self.max_var_loss1((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale - offset = with_loss(offset, - self.max_var_loss2(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + #offset = with_loss(offset, + # self.max_var_loss2(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) offset = self.offset_scale_limiter(offset, 0.05 * aux_loss_scale) @@ -651,13 +651,14 @@ def forward( src = src_orig + offset - src = with_loss(src, - self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + #src = with_loss(src, + # self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) # also put cosine_loss on src, mostly because it will be used in scale_limiter and we don't want the # network to get around the scale limitation by using an offset. - src = with_loss(src, - self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + + #src = with_loss(src, + # self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) src = self.norm(src) From 05a6905a4ee3dfb84ed3e29b0c5ac112c26f53be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 19:52:39 +0800 Subject: [PATCH 0651/1191] Change power on eigenvalues of M from 4 to 2. --- egs/librispeech/ASR/zipformer/scaling.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e6ac701125..2b05fbd1d1 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1626,17 +1626,11 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 r = torch.matmul(r, X.t()) # (M, N) r = torch.matmul(r, Y) # (M, dim) r = r * (1. / N) - r = torch.matmul(r, Y.t()) # (M, N) - r = torch.matmul(r, X) # (M, dim) - r = r * (1. / N) - metric = (r ** 2).mean() + metric = (r ** 2).mean().sqrt() # now, with reference to the comment for class CorrelationLimiter, - # metric should equal an estimate of tr(M^T M M^T M) / dim. - - metric = metric ** (1/4) - # now we have a metric that's proportional in scale to the eigenvalues of - # M, where M is E[x y^T] + # metric should, I believe, equal an estimate of sqrt(tr(M^T M) / dim), + # which should be an rms of the singular values of M. loss = (metric - limit).relu() From 5e55141129ce932955e21bfb07eecaa7822ab2e3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Oct 2025 21:54:05 +0800 Subject: [PATCH 0652/1191] Revert "Remove max_var_loss{1,2}, cosine_loss and offset_cosine_loss from Zipformer2EncoderLayer." This reverts commit ccfa0643a6d817a52d30ca20b2f1e03bf695beb5. --- egs/librispeech/ASR/zipformer/zipformer.py | 25 +++++++++++----------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9c156f1b75..9cef6b9f5f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -555,10 +555,10 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - #self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) - #self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - #self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) - #self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) + self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) + self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) + self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) + self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) self.offset_correlation_limiter = CorrelationLimiter(limit=0.25 * 0.1 * 0.25) @@ -634,14 +634,14 @@ def forward( src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - #src = with_loss(src, - # self.max_var_loss1((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + src = with_loss(src, + self.max_var_loss1((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale - #offset = with_loss(offset, - # self.max_var_loss2(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + offset = with_loss(offset, + self.max_var_loss2(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) offset = self.offset_scale_limiter(offset, 0.05 * aux_loss_scale) @@ -651,14 +651,13 @@ def forward( src = src_orig + offset - #src = with_loss(src, - # self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + src = with_loss(src, + self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) # also put cosine_loss on src, mostly because it will be used in scale_limiter and we don't want the # network to get around the scale limitation by using an offset. - - #src = with_loss(src, - # self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) + src = with_loss(src, + self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) src = self.norm(src) From 3fa5c76fe23065072334ec61026b3b741c0a410d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 29 Oct 2025 10:19:52 +0800 Subject: [PATCH 0653/1191] Modify correlation limiter to have no limit and use cross correlation across even/odd batch elements. Increase aux loss scale. --- egs/librispeech/ASR/zipformer/scaling.py | 67 ++++++++-------------- egs/librispeech/ASR/zipformer/zipformer.py | 7 ++- 2 files changed, 29 insertions(+), 45 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2b05fbd1d1..97ca8d6991 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1593,9 +1593,8 @@ def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: class CorrelationLimiterFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, y: Tensor, limit: float, aux_loss_scale: float, mask: Optional[Tensor], name: str): + def forward(ctx, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Tensor], name: str): ctx.save_for_backward(x, y) - ctx.limit = limit ctx.mask = mask ctx.aux_loss_scale = aux_loss_scale ctx.name = name @@ -1606,8 +1605,8 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 x, y = ctx.saved_tensors dim = x.shape[-1] mask = ctx.mask - limit = ctx.limit aux_loss_scale = ctx.aux_loss_scale + (batch_size, seq_len, num_channels) = x.shape with torch.enable_grad(): with torch.amp.autocast('cuda', enabled=False): if mask is not None: @@ -1615,72 +1614,56 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 x = x * mask x, y = x.to(torch.float), y.to(torch.float) x, y = x.detach(), y.detach() + x_orig, y_orig = x, y x.requires_grad = True y.requires_grad = True + half_batch = batch_size // 2 + if half_batch <= 1: + # the reason we also return None if half_batch==1 is because of CR-CTC + # where they may really be duplicates + return None, None, None, None, None - X, Y = x.reshape(-1, dim), y.reshape(-1, dim) + x = x[:2*half_batch] + y = y[:2*half_batch] - N = X.shape[0] - M = 32 # number of random vectors, this should be more than enough. - r = torch.randn(M, dim, device=x.device) # (M, dim) - r = torch.matmul(r, X.t()) # (M, N) - r = torch.matmul(r, Y) # (M, dim) - r = r * (1. / N) + M = 64 # number of random vectors, this should be more than enough. + r = torch.randn(half_batch, M, dim, device=x.device).repeat_interleave(2, dim=0) + # r: (batch_size, M, dim) + r = torch.matmul(x, r.transpose(1, 2)) # (batch_size, seq_len, m) + r = torch.matmul(r.transpose(1, 2), y) # (batch_size, m, dim) - metric = (r ** 2).mean().sqrt() - # now, with reference to the comment for class CorrelationLimiter, - # metric should, I believe, equal an estimate of sqrt(tr(M^T M) / dim), - # which should be an rms of the singular values of M. - - loss = (metric - limit).relu() + # correlation between tr(M) estimates between elements of the batch. + correlation = r[0::2] * r[1::2] if random.random() < 0.001: logging.info( - f"CorrelationLimiter: name={ctx.name}, limit={limit}, " - f"metric={metric.item()}, loss={loss.item()}, " - f"loss_scale={aux_loss_scale}" + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation.mean()}" ) - loss.backward(gradient=torch.full_like(loss, aux_loss_scale * N)) + correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale / num_channels)) - return x.grad, y.grad, None, None, None, None + return x_orig.grad, y_orig.grad, None, None, None class CorrelationLimiter(torch.nn.Module): """ - Adds a penalty in backprop if feature x and feature y are too correlated, - based on a randomized algorithm. The correlation limit is specified based - on a limit on tr(M^T M M^T M) / dim, where: - M = E [x y^t], - and this is the same as the mean of the [singular values of M taken to the - fourth power.] We can measure this, concretely, as: - E[ ||M^T M n||_2^2 ] - where n is Gaussian noise, we do this for several vectors x. We can - compute M^T M n as mean[ x_i y_i^T y_i x_i^T n ] = (X^T (Y (Y^T (X n)))). - - Now, the eigenvalues of M should be related to the [magnitude of x] * [magnitude of y] - * [some factor that expresses their correlation.], and the magnitudes of x and y - are assumed to be fixed by the SoftNorm. So the user can express a limit in terms - of a [expected magnitude of x] * [expected magnitude of y] * [max correlation], - and we can take E[ ||M^T M n||_2^2 ] to the power 1/4 before comparing to the limit. - - Assumes channel dim is -1 and the input shape has >1 dimension. + Adds a penalty in backprop if feature x and feature y are correlated. + Assumes input is (batch, seq, channel) """ - def __init__(self, limit: FloatLike): - # dimensionally, limit is [expected magnitude of x] * [expected magnitude of y] * [max correlation coefficient] + def __init__(self): super().__init__() self.name = None - self.limit = limit def forward(self, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: + # x and y should both be: (batch, seq, channel) # returns a scalar tensor that should be included in the loss function with: # z = with_loss(z, ret, None) # where z is any quantity that will be used in calculating the main loss. if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return torch.tensor(0.0, device=x.device) else: - return CorrelationLimiterFunction.apply(x, y, float(self.limit), + return CorrelationLimiterFunction.apply(x, y, aux_loss_scale, mask, self.name) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9cef6b9f5f..408c543a69 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -560,7 +560,7 @@ def __init__( self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) - self.offset_correlation_limiter = CorrelationLimiter(limit=0.25 * 0.1 * 0.25) + self.offset_correlation_limiter = CorrelationLimiter() self.self_attn_weights = RelPositionMultiheadAttentionWeights( @@ -646,8 +646,9 @@ def forward( offset = self.offset_scale_limiter(offset, 0.05 * aux_loss_scale) offset = with_loss(offset, - self.offset_correlation_limiter(offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), - 0.1 * aux_loss_scale, mask=src_key_padding_mask)) + self.offset_correlation_limiter( + offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), + aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From d75a05134809898bea31a071aa3044319ecd6a2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 29 Oct 2025 10:30:31 +0800 Subject: [PATCH 0654/1191] Correct for seq_len --- egs/librispeech/ASR/zipformer/scaling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 97ca8d6991..8031f89e20 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1631,6 +1631,7 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 # r: (batch_size, M, dim) r = torch.matmul(x, r.transpose(1, 2)) # (batch_size, seq_len, m) r = torch.matmul(r.transpose(1, 2), y) # (batch_size, m, dim) + r = r * (1. / seq_len) # correlation between tr(M) estimates between elements of the batch. correlation = r[0::2] * r[1::2] From 101baed25e8f9f6c51d8e2681784871f24a8b3e1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 29 Oct 2025 20:05:28 +0800 Subject: [PATCH 0655/1191] Reduce aux_loss scale by factor of 10 for correlation limiter --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 408c543a69..a99b7068a2 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -648,7 +648,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), - aux_loss_scale, mask=src_key_padding_mask)) + 0.1 * aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From 8cff0d47b99563125b9244a404bdabee937f7d1f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 29 Oct 2025 20:24:41 +0800 Subject: [PATCH 0656/1191] Remove two cosine losses from the zapformer layer. --- egs/librispeech/ASR/zipformer/zipformer.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a99b7068a2..1544d1c51c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -555,8 +555,6 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.7)) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) @@ -652,14 +650,6 @@ def forward( src = src_orig + offset - src = with_loss(src, - self.offset_cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) - - # also put cosine_loss on src, mostly because it will be used in scale_limiter and we don't want the - # network to get around the scale limitation by using an offset. - src = with_loss(src, - self.cosine_loss(src.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) - src = self.norm(src) return src From 7ec5ff5d22ae7a090a0c7071e3696b87241dad83 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 29 Oct 2025 21:44:11 +0800 Subject: [PATCH 0657/1191] Remove cosine losses at zipformer level --- egs/librispeech/ASR/zipformer/zipformer.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 1544d1c51c..a1ba8188cd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -811,8 +811,9 @@ def __init__( self.copy_bypass = Identity() - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) - self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) + #self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) + #self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) + self.offset_max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (5000.0, 0.25), default=1.0)) @@ -876,12 +877,6 @@ def forward( # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, # because of some identities involving orthogonal matrices. - if aux_loss_scale: - src = with_loss(src, - self.offset_cosine_loss(offset.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask) + - self.cosine_loss(src.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask)) return src From 58ee68ca14edf94a469029771aa1c32b5c21c7d3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Oct 2025 08:47:59 +0800 Subject: [PATCH 0658/1191] Remove all cosine losses. --- egs/librispeech/ASR/zipformer/subsampling.py | 2 - egs/librispeech/ASR/zipformer/zipformer.py | 49 -------------------- 2 files changed, 51 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 44bc01c0c9..41d3cd9510 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -234,7 +234,6 @@ def __init__( self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, initial_scale=4.0) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(out_channels, power=0.75)) self.scale_limiter = ScaleLimiter(max_rms=2.0) @@ -277,7 +276,6 @@ def forward( key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) - x = with_loss(x, self.cosine_loss(x, aux_loss_scale, key_padding_mask), None) x = self.scale_limiter(x, aux_loss_scale) x = self.out_norm(x) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a1ba8188cd..ace7070e08 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -444,31 +444,6 @@ def get_init_states( return states -def get_max_similarity(rank: int, power: float): - """ - This returns a value for the "max_similarity" argument of CosineSimilarityLoss. - the max_similarity is an upper limit we impose on the mean value of (x_i . x_j) - if i != j are two different sequence-position indexes and x_i and x_j are - activation vectors normalized to have unit length. - - rank: the dimension of the space, usually this is the num_channels, but if - we have just up-projected from a bottleneck, it would be the bottleneck - dimension. - power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean - we enforce the vector dimensions to be completely independent like Gaussian noise - (don't do this); if we set power=0.0 it would be equivalent to not having - the CosineSimilarityLoss at all. - - The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal - variable. If x consists of independent Gaussian noise of dimension D, with - variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" - would be close to a no-op for large D), then (x_i . x_j) would be distributed as - a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) - would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value - between this and 1, as a kind of heuristic limit on this max_similarity. - """ - return (0.7978845608 / (rank ** 0.5)) ** power - def pad_mask(mask: Optional[Tensor], seq_len: int): # mask: (batch_size, old_seq_len) # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). @@ -811,8 +786,6 @@ def __init__( self.copy_bypass = Identity() - #self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) - #self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) self.offset_max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (5000.0, 0.25), default=1.0)) @@ -1202,8 +1175,6 @@ def __init__( ) - self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.5)) - self.linear_pos = ScaledLinear( pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 ) @@ -1257,13 +1228,6 @@ def forward( p = q[..., :pos_head_dim] p = self.copy_pos_query(p) # diagnostics only - if aux_loss_scale: - k = with_loss(k, - self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), - aux_loss_scale / num_heads, - key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None)) - - # time1 refers to target, time2 refers to source. q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) @@ -1516,7 +1480,6 @@ def __init__( f = max(1.0, embed_dim / (num_heads * value_head_dim)) - self.cosine_loss = CosineSimilarityLoss(max_similarity=0.75) def forward( @@ -1559,11 +1522,6 @@ def forward( # returned value is of shape (seq_len, batch_size, embed_dim), like the input. x = self.out_proj(x) - if aux_loss_scale: - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), - aux_loss_scale, - mask=src_key_padding_mask)) - return x def streaming_forward( @@ -1644,13 +1602,10 @@ def __init__(self, embed_dim: int, feedforward_dim: int): initial_scale=0.5, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.65)) - def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: x = self.in_proj(x) x = self.out_proj(x) - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask)) return x @@ -1715,8 +1670,6 @@ def __init__( dropout_p=0.0, initial_scale=0.05, ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.6)) - def forward( self, @@ -1770,8 +1723,6 @@ def forward( x = self.out_proj(x) # (time, batch, channels) - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask)) - return x def streaming_forward( From c4a8c45d0e2f143d89229410a37d08afd0ea41d7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 30 Oct 2025 09:29:52 +0800 Subject: [PATCH 0659/1191] Add four more layers in central stack --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7f73b64115..dc9c32f193 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,12,12,12,9", + default="6,9,12,16,12,9", help="Number of zipformer encoder layers per stack, comma separated.", ) From f266e3f19c54b96be7491c0dd395f3433561237e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 1 Nov 2025 15:51:36 +0800 Subject: [PATCH 0660/1191] Make middle stack very deep, remove two stacks. --- egs/librispeech/ASR/zapformer/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index dc9c32f193..ee384fe2da 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,14 +185,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,12,16,12,9", + default="6,9,30,9", help="Number of zipformer encoder layers per stack, comma separated.", ) parser.add_argument( "--downsampling-factor", type=str, - default="1,2,4,8,4,2", + default="1,2,4,2", help="Downsampling factor for each stack of encoder layers.", ) @@ -213,21 +213,21 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--feedforward-multiple", type=str, - default="3,3,3,3,3,3", + default="3,3,3,3", help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", ) parser.add_argument( "--num-heads", type=str, - default="4,4,4,8,4,4", + default="4,4,4,4", help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) parser.add_argument( "--encoder-multiple", type=str, - default="4,6,9,12,9,6", + default="4,6,10,6", help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="31,31,15,15,15,31", + default="31,31,15,31", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", ) From a9343c4588a1385a876b59bf5866aad9d5bc8af0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 1 Nov 2025 15:53:38 +0800 Subject: [PATCH 0661/1191] Combine middle three stacks into one very deep one. Increase feedforward multiple to 4. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ee384fe2da..0bf6495817 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -213,7 +213,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--feedforward-multiple", type=str, - default="3,3,3,3", + default="4", help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", ) From 3c92b14f391f5f330275656f74c88fca21d73f06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 1 Nov 2025 16:36:19 +0800 Subject: [PATCH 0662/1191] Remove factor of 0.1 on the correlation loss aux_loss_scale. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ace7070e08..347516bdc8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -621,7 +621,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), - 0.1 * aux_loss_scale, mask=src_key_padding_mask)) + aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From a174385522c1a24fca6449380d0cb17897d08004 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 1 Nov 2025 19:42:58 +0800 Subject: [PATCH 0663/1191] Reduce middle num layers from 30 to 20. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 0bf6495817..63da526f62 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,30,9", + default="6,9,20,9", help="Number of zipformer encoder layers per stack, comma separated.", ) From 00148bc72da90ee928dff8fcafe2327def7d4fe8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 1 Nov 2025 22:18:01 +0800 Subject: [PATCH 0664/1191] Increase central num layers from 20 to 24. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 63da526f62..7935fb12f2 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,20,9", + default="6,9,24,9", help="Number of zipformer encoder layers per stack, comma separated.", ) From e75720436ada716116f592dccf05ad907d8d66a4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 2 Nov 2025 09:45:10 +0800 Subject: [PATCH 0665/1191] Increase encoder-multiple from 4,6,10,6 to 5,8,12,8, decrease central num-layers 24->22. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7935fb12f2..95bcf048a2 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,24,9", + default="6,9,22,9", help="Number of zipformer encoder layers per stack, comma separated.", ) @@ -227,7 +227,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-multiple", type=str, - default="4,6,10,6", + default="5,8,12,8", help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", ) From 958351ca00991c5c13f7a9984d73ef570ecd4938 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 2 Nov 2025 15:27:08 +0800 Subject: [PATCH 0666/1191] Change power in Eden schedule from -0.5 to -0.4. --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6c01fbdff1..131d06f736 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1102,7 +1102,7 @@ class Eden2(LRScheduler): only batches. The basic formula (before warmup) is: - lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup + lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.4) * warmup where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches and then stays constant at 1. @@ -1133,7 +1133,7 @@ def __init__( def get_lr(self): factor = ( (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.5 + ) ** -0.4 warmup_factor = ( 1.0 if self.batch >= self.warmup_batches From 0427f4141bc83d4fd22af3e569aba250a4455ae9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 2 Nov 2025 15:38:35 +0800 Subject: [PATCH 0667/1191] Reduce power in Eden further from -0.4 to -0.33. --- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 95bcf048a2..ae6ba80d09 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1403,7 +1403,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.66) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 131d06f736..cf1134d813 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1102,7 +1102,7 @@ class Eden2(LRScheduler): only batches. The basic formula (before warmup) is: - lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.4) * warmup + lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.33) * warmup where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches and then stays constant at 1. @@ -1133,7 +1133,7 @@ def __init__( def get_lr(self): factor = ( (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.4 + ) ** -0.33 warmup_factor = ( 1.0 if self.batch >= self.warmup_batches From c4687788b4eb92b7ac6848587e513d1e9de20965 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 2 Nov 2025 16:05:20 +0800 Subject: [PATCH 0668/1191] Take optim.py from 1526. --- egs/librispeech/ASR/zipformer/optim.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index cf1134d813..008c176f41 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1102,7 +1102,7 @@ class Eden2(LRScheduler): only batches. The basic formula (before warmup) is: - lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.33) * warmup + lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches and then stays constant at 1. @@ -1133,7 +1133,7 @@ def __init__( def get_lr(self): factor = ( (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.33 + ) ** -0.5 warmup_factor = ( 1.0 if self.batch >= self.warmup_batches @@ -1197,13 +1197,13 @@ def __init__( lr_batches: Union[int, float], warmup_batches: Union[int, float] = 500.0, warmup_start: float = 0.5, - p: float = 1.0, + power: float = 1.0, verbose: bool = False, ): super().__init__(optimizer, verbose) self.lr_batches = lr_batches self.warmup_batches = warmup_batches - self.p = p + self.power = power assert 0.0 <= warmup_start <= 1.0, warmup_start self.warmup_start = warmup_start @@ -1211,7 +1211,7 @@ def get_lr(self): lr_batches = self.lr_batches e = 2.71828 batch = self.batch - p = self.p + p = self.power factor = ((p * lr_batches / batch) ** p if batch > p * e * lr_batches else e ** (-batch / (e * lr_batches))) From 01601331a8917e71186126d615a370a834c9da2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 2 Nov 2025 21:59:01 +0800 Subject: [PATCH 0669/1191] Change power in schedule from 0.66 to 0.5 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ae6ba80d09..ca45138f7f 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1403,7 +1403,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.66) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 02ab64409e882080d75d0ccd0892f2d641633102 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 2 Nov 2025 22:01:15 +0800 Subject: [PATCH 0670/1191] Increase valid_interval from 3000 to 10000. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ca45138f7f..bc173a93ce 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -672,7 +672,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 + "valid_interval": 10000, # parameters for zipformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. From e654e8a948be9c5f1e5b1da32bee2a2f3325b8e0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Nov 2025 16:50:51 +0800 Subject: [PATCH 0671/1191] Combine the two self_attn modules into one. The value_head_dim is now the specified value, not twice that value. --- egs/librispeech/ASR/zipformer/zipformer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 347516bdc8..e6e4fadbd7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -544,14 +544,13 @@ def __init__( pos_head_dim=pos_head_dim, ) - self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] + self.self_attn = SelfAttention(embed_dim, num_heads, value_head_dim) feedforward_dim = embed_dim * feedforward_multiple - self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4) + self.feed_forward1 = FeedforwardModule(embed_dim, feedforward_dim) self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim) - self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) self.norm = ExpNorm(embed_dim) @@ -595,13 +594,10 @@ def forward( key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale, ) - attn_weights1, attn_weights2 = attn_weights.chunk(2, dim=0) - - src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.self_attn(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) From c2bd6f881599af8f31a693ddb2a1640fbaaa4d6d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Nov 2025 16:52:09 +0800 Subject: [PATCH 0672/1191] Increase central num heads from 4 to 6. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index bc173a93ce..4fe3e22276 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -220,7 +220,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-heads", type=str, - default="4,4,4,4", + default="4,4,6,4", help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) From fec523e3cb6b206dda99d926bab76f22c3ef86c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Nov 2025 16:54:26 +0800 Subject: [PATCH 0673/1191] Reduce central feedforward-multiple from 4 to 3. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 4fe3e22276..4282d75166 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -213,7 +213,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--feedforward-multiple", type=str, - default="4", + default="4,4,3,4", help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", ) From 7af8f2b8d1261daa990f9f652b8449c20b2dff48 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Nov 2025 17:06:53 +0800 Subject: [PATCH 0674/1191] Reduce central num layers from 22 to 20. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 4282d75166..99139a4f14 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,22,9", + default="6,9,20,9", help="Number of zipformer encoder layers per stack, comma separated.", ) From ca21468ea243a819d9ee057b9eaed8e86969e148 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 3 Nov 2025 17:10:40 +0800 Subject: [PATCH 0675/1191] Increase value-head-from from 12 to 20. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 99139a4f14..5fe00ef158 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -241,7 +241,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--value-head-dim", type=str, - default="12", + default="20", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) From 0addb585446d9241c34b64640de4e7c101200f69 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Nov 2025 12:22:27 +0800 Subject: [PATCH 0676/1191] Increase central num layers from 20 to 26 and decrease central num-heads from 6 to 4. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 5fe00ef158..b3c6fb9a09 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,20,9", + default="6,9,26,9", help="Number of zipformer encoder layers per stack, comma separated.", ) @@ -220,7 +220,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-heads", type=str, - default="4,4,6,4", + default="4,4,4,4", help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) From 6335e25549eba856e342a4b2a5dcdb1b91e6fbc7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Nov 2025 12:38:04 +0800 Subject: [PATCH 0677/1191] Remove max_var_loss and scale_limiter instances from zipformer and frontend. --- egs/librispeech/ASR/zipformer/subsampling.py | 4 ---- egs/librispeech/ASR/zipformer/zipformer.py | 17 ----------------- 2 files changed, 21 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 41d3cd9510..b901568fed 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -21,7 +21,6 @@ import torch from scaling import ( - ScaleLimiter, ScaledLinear, ExpNorm, FloatLike, @@ -235,8 +234,6 @@ def __init__( initial_scale=4.0) - self.scale_limiter = ScaleLimiter(max_rms=2.0) - self.out_norm = ExpNorm(out_channels) def forward( @@ -276,7 +273,6 @@ def forward( key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) - x = self.scale_limiter(x, aux_loss_scale) x = self.out_norm(x) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e6e4fadbd7..e5f402326c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -40,12 +40,9 @@ limit_param_value, penalize_abs_values_gt, softmax, - ScaleLimiter, with_loss, ) try: - from scaling import NormChangeLoss - from scaling import MaxVarLoss from scaling import CorrelationLimiter except: pass @@ -530,9 +527,6 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - self.max_var_loss1 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.2), default=1.0)) - self.max_var_loss2 = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (10000.0, 0.1), default=1.0)) - self.offset_scale_limiter = ScaleLimiter(max_rms=0.25) self.offset_correlation_limiter = CorrelationLimiter() @@ -603,17 +597,9 @@ def forward( src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = with_loss(src, - self.max_var_loss1((src - src_orig).permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) - residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale - offset = with_loss(offset, - self.max_var_loss2(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask)) - - offset = self.offset_scale_limiter(offset, 0.05 * aux_loss_scale) - offset = with_loss(offset, self.offset_correlation_limiter( offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), @@ -782,8 +768,6 @@ def __init__( self.copy_bypass = Identity() - self.offset_max_var_loss = MaxVarLoss(max_rms=ScheduledFloat((0.0, 0.5), (5000.0, 0.25), default=1.0)) - def forward( @@ -1616,7 +1600,6 @@ class ConvolutionModule(nn.Module): bias (bool): Whether to use bias in conv layers (default=True). """ - def __init__( self, channels: int, From 8ef379ac90a07918e4e6430daaa1f422d71d71f6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Nov 2025 13:00:38 +0800 Subject: [PATCH 0678/1191] Increase middle conv kernel size from 15 to 31. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b3c6fb9a09..a74ef18fcb 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="31,31,15,31", + default="31,31,31,31", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", ) From 8c9b83463b4e2eedd82277c72d7293375c7a2653 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 4 Nov 2025 13:09:19 +0800 Subject: [PATCH 0679/1191] Revert change to subsampling.py --- egs/librispeech/ASR/zipformer/subsampling.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index b901568fed..41d3cd9510 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -21,6 +21,7 @@ import torch from scaling import ( + ScaleLimiter, ScaledLinear, ExpNorm, FloatLike, @@ -234,6 +235,8 @@ def __init__( initial_scale=4.0) + self.scale_limiter = ScaleLimiter(max_rms=2.0) + self.out_norm = ExpNorm(out_channels) def forward( @@ -273,6 +276,7 @@ def forward( key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) + x = self.scale_limiter(x, aux_loss_scale) x = self.out_norm(x) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) From 38490c809406364771e7a4851a86fc7118ab4845 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Nov 2025 00:07:04 +0800 Subject: [PATCH 0680/1191] Replace convolution modules with fft based convolutions. --- egs/librispeech/ASR/zipformer/zipformer.py | 116 +++++++++++++++------ 1 file changed, 86 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e5f402326c..bd33e2d1d3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1588,6 +1588,65 @@ def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: x = self.out_proj(x) return x +def round_up_to_power_of_two(x): + x = x - 1 + x = x | x >> 1 + x = x | x >> 2 + x = x | x >> 4 + x = x | x >> 8 + x = x | x >> 16 + x = x + 1 + return x + + +class FftModule(nn.Module): + def __init__(self, + num_channels: int, + params_per_channel: int, + min_pad: int = 32): + super().__init__() + # initialize to identity function. + self.weight = nn.Parameter( torch.stack((torch.ones(num_channels, params_per_channel), + torch.zeros(num_channels, params_per_channel)), + dim=0)) + # self.weight: (2, num_channels, params_per_channel)( + self.min_pad = min_pad + + + def forward(self, + x: Tensor) -> Tensor: + (seq_len, batch_size, num_channels) = x.shape + + n = round_up_to_power_of_two(seq_len + self.min_pad) + x = torch.fft.rfft(x, n=n, dim=0, norm="ortho") + + N = x.shape[0] # N == n/2 + 1, the number of fourier components. + # x: (N, batch_size, num_channels) + + weight = self.upsample_weight(N) + # weight: (num_channels, N) + + x = x * weight.t().unsqueeze(1) + + x = torch.fft.irfft(x, dim=0, norm="ortho") + + return x[:seq_len] + + + def upsample_weight(self, N: int) -> Tensor: + # N is the desired number of frequencies of weight, so we return + # a complex weight of shape (num_channels, N). + + weight = self.weight + num_channels = weight.shape[0] // 2 + # the following may not be ideal, we'll see. + weight = torch.nn.functional.upsample(weight, N, mode='linear', align_corners=True) + + weight = torch.view_as_complex(weight.permute(1, 2, 0).contiguous()) + # weight: (num_channels, N) + return weight + + class ConvolutionModule(nn.Module): @@ -1630,17 +1689,10 @@ def __init__( assert kernel_size % 2 == 1 - self.depthwise_conv = ( - ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) - if causal - else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - ) + self.fft_conv = FftModule(num_channels=bottleneck_dim, + params_per_channel=kernel_size, + min_pad=32) + self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, @@ -1677,28 +1729,12 @@ def forward( x = x * s x = self.activation2(x) # identity - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). + #x: (time, batch, channels) if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if ( - not torch.jit.is_scripting() - and not torch.jit.is_tracing() - and chunk_size >= 0 - ): - # Not support exporting a model for simulated streaming decoding - assert ( - self.causal - ), "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - x = x.permute(2, 0, 1) # (time, batch, channels) + x = self.fft_conv(x) x = self.out_proj(x) # (time, batch, channels) @@ -1787,6 +1823,25 @@ def _test_zipformer_main(causal: bool = False): ) x_ # to remove flake8 warnings +def _test_fft_module(): + num_channels = 110 + f = FftModule(num_channels=num_channels, + params_per_channel=10, + min_pad=4) + + batch_size = 5 + seq_len = 50 + x = torch.randn(seq_len, batch_size, num_channels) + + y = f(x) + + def rms(a): + return (a**2).mean().item() + + print(f"rms(y)={rms(y)}, rms(x-y)={rms(x-y)}") + + + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) @@ -1794,3 +1849,4 @@ def _test_zipformer_main(causal: bool = False): torch.set_num_interop_threads(1) _test_zipformer_main(False) _test_zipformer_main(True) + _test_fft_module() From b1fa3984ce208dec6267eaf4bb424bfe8de066bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Nov 2025 10:29:37 +0800 Subject: [PATCH 0681/1191] Restore scale_limiter, but with larger scale, to avoid extremely large values developing. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bd33e2d1d3..bf8eb7391d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -40,6 +40,7 @@ limit_param_value, penalize_abs_values_gt, softmax, + ScaleLimiter, with_loss, ) try: @@ -527,6 +528,7 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) + self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) self.offset_correlation_limiter = CorrelationLimiter() @@ -600,6 +602,8 @@ def forward( residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) offset = (src - src_orig) * residual_scale + offset = self.offset_scale_limiter(offset, aux_loss_scale) + offset = with_loss(offset, self.offset_correlation_limiter( offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), From 8e1b2a47b78f112507abf1276c220659eab05809 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Nov 2025 20:34:35 +0800 Subject: [PATCH 0682/1191] Have convolution be done twice, once with transpose; activation in middle. --- egs/librispeech/ASR/zipformer/zipformer.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bf8eb7391d..d13cb1d184 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -36,6 +36,7 @@ CosineSimilarityLoss, ScheduledFloat, FloatLike, + SwashR, convert_num_channels, limit_param_value, penalize_abs_values_gt, @@ -1618,7 +1619,8 @@ def __init__(self, def forward(self, - x: Tensor) -> Tensor: + x: Tensor, + transpose: bool = False) -> Tensor: (seq_len, batch_size, num_channels) = x.shape n = round_up_to_power_of_two(seq_len + self.min_pad) @@ -1628,6 +1630,12 @@ def forward(self, # x: (N, batch_size, num_channels) weight = self.upsample_weight(N) + eps = 1.0e-05 + # half-normalize the weight. + weight = weight / (weight.abs() + eps).sqrt() + if transpose: + # reverse the time direction of the kernel. + weight = weight.conj() # weight: (num_channels, N) x = x * weight.t().unsqueeze(1) @@ -1697,12 +1705,11 @@ def __init__( params_per_channel=kernel_size, min_pad=32) + self.activation3 = SwashR() - self.out_proj = ActivationDropoutAndLinear( + self.out_proj = ScaledLinear( bottleneck_dim, channels, - activation="SwashR", - dropout_p=0.0, initial_scale=0.05, ) @@ -1740,6 +1747,10 @@ def forward( x = self.fft_conv(x) + x = self.activation3(x) + + x = self.fft_conv(x, transpose=True) + x = self.out_proj(x) # (time, batch, channels) return x From d554714289d11ef369af5b4b6e92d705f0524da1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Nov 2025 20:53:41 +0800 Subject: [PATCH 0683/1191] Decouple first and second convolution. --- egs/librispeech/ASR/zipformer/zipformer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d13cb1d184..7f131b6f04 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1701,7 +1701,11 @@ def __init__( assert kernel_size % 2 == 1 - self.fft_conv = FftModule(num_channels=bottleneck_dim, + self.fft_conv1 = FftModule(num_channels=bottleneck_dim, + params_per_channel=kernel_size, + min_pad=32) + + self.fft_conv2 = FftModule(num_channels=bottleneck_dim, params_per_channel=kernel_size, min_pad=32) @@ -1745,11 +1749,11 @@ def forward( if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - x = self.fft_conv(x) + x = self.fft_conv1(x) x = self.activation3(x) - x = self.fft_conv(x, transpose=True) + x = self.fft_conv2(x) x = self.out_proj(x) # (time, batch, channels) From 5651613f11b971b476f6c4a68e38194aa7503929 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Nov 2025 23:32:51 +0800 Subject: [PATCH 0684/1191] Add back another activation at the end of the conv layer --- egs/librispeech/ASR/zipformer/zipformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7f131b6f04..d7b207035d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1711,9 +1711,11 @@ def __init__( self.activation3 = SwashR() - self.out_proj = ScaledLinear( + self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, + activation="SwashR", + dropout_p=0.0, initial_scale=0.05, ) From c581555bbafc76b7af1d31cb37f22f1b90948cac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 5 Nov 2025 23:58:07 +0800 Subject: [PATCH 0685/1191] Introduce bias before out_proj. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d7b207035d..bd2f1bf305 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1711,6 +1711,8 @@ def __init__( self.activation3 = SwashR() + self.out_bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) + self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, @@ -1757,6 +1759,8 @@ def forward( x = self.fft_conv2(x) + x = x + self.out_bias + x = self.out_proj(x) # (time, batch, channels) return x From 1849ee903d8f119d25e2068947e8b7a24039314c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Nov 2025 11:15:28 +0800 Subject: [PATCH 0686/1191] Move masking to before in_proj; add another bias prior to middle activation. --- egs/librispeech/ASR/zipformer/zipformer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bd2f1bf305..8755c9dff3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1614,6 +1614,7 @@ def __init__(self, self.weight = nn.Parameter( torch.stack((torch.ones(num_channels, params_per_channel), torch.zeros(num_channels, params_per_channel)), dim=0)) + self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) # self.weight: (2, num_channels, params_per_channel)( self.min_pad = min_pad @@ -1642,7 +1643,7 @@ def forward(self, x = torch.fft.irfft(x, dim=0, norm="ortho") - return x[:seq_len] + return x[:seq_len] + self.bias def upsample_weight(self, N: int) -> Tensor: @@ -1711,6 +1712,7 @@ def __init__( self.activation3 = SwashR() + self.middle_bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) self.out_bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) self.out_proj = ActivationDropoutAndLinear( @@ -1739,9 +1741,13 @@ def forward( Tensor: Output tensor (#time, batch, channels). """ - x = self.in_proj(x) # (time, batch, 2*channels) + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) + + x = self.in_proj(x) # (time, batch, 2*channels) + x, s = x.chunk(2, dim=2) s = self.sigmoid(s) x = self.activation1(x) # identity. @@ -1749,18 +1755,12 @@ def forward( x = self.activation2(x) # identity #x: (time, batch, channels) - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - x = self.fft_conv1(x) x = self.activation3(x) x = self.fft_conv2(x) - x = x + self.out_bias - x = self.out_proj(x) # (time, batch, channels) return x From 553d0cce1fc559b9ca2f209111a20969a6886764 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Nov 2025 16:06:26 +0800 Subject: [PATCH 0687/1191] Double kernel size / num params of fft_conv modules. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index a74ef18fcb..fbfd6b341f 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="31,31,31,31", + default="63,63,63,63", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", ) From ac474fc83598a298daa0bc8a5d111052b96f67f4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Nov 2025 16:51:11 +0800 Subject: [PATCH 0688/1191] Introduce projection into fft_conv modules. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8755c9dff3..940c132db9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1614,6 +1614,7 @@ def __init__(self, self.weight = nn.Parameter( torch.stack((torch.ones(num_channels, params_per_channel), torch.zeros(num_channels, params_per_channel)), dim=0)) + self.weight_proj = nn.Linear(params_per_channel, params_per_channel) self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) # self.weight: (2, num_channels, params_per_channel)( self.min_pad = min_pad @@ -1650,7 +1651,7 @@ def upsample_weight(self, N: int) -> Tensor: # N is the desired number of frequencies of weight, so we return # a complex weight of shape (num_channels, N). - weight = self.weight + weight = self.weight_proj(self.weight) num_channels = weight.shape[0] // 2 # the following may not be ideal, we'll see. weight = torch.nn.functional.upsample(weight, N, mode='linear', align_corners=True) From 1c839041617b67db20605f8c35b7b0b1e282f5bb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Nov 2025 14:30:12 +0800 Subject: [PATCH 0689/1191] Remove the half-normalization of weight magnitudes in FftConv. --- egs/librispeech/ASR/zipformer/zipformer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 940c132db9..6614973879 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1621,8 +1621,7 @@ def __init__(self, def forward(self, - x: Tensor, - transpose: bool = False) -> Tensor: + x: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape n = round_up_to_power_of_two(seq_len + self.min_pad) @@ -1633,11 +1632,6 @@ def forward(self, weight = self.upsample_weight(N) eps = 1.0e-05 - # half-normalize the weight. - weight = weight / (weight.abs() + eps).sqrt() - if transpose: - # reverse the time direction of the kernel. - weight = weight.conj() # weight: (num_channels, N) x = x * weight.t().unsqueeze(1) From 50360adf78c34e8d750f8eee02faa0dc2d108b24 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Nov 2025 17:00:04 +0800 Subject: [PATCH 0690/1191] Halve num params but keep the num points the same by having projection increase dimension by 2. --- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index fbfd6b341f..5f95b7aaef 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="63,63,63,63", + default="32", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 940c132db9..af787f3e74 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1611,10 +1611,10 @@ def __init__(self, min_pad: int = 32): super().__init__() # initialize to identity function. - self.weight = nn.Parameter( torch.stack((torch.ones(num_channels, params_per_channel), - torch.zeros(num_channels, params_per_channel)), - dim=0)) - self.weight_proj = nn.Linear(params_per_channel, params_per_channel) + self.weight = nn.Parameter(torch.ones(num_channels, params_per_channel)) + + # the factor of 2 is for (sin, cos)a. + self.weight_proj = nn.Linear(params_per_channel, 2 * params_per_channel) self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) # self.weight: (2, num_channels, params_per_channel)( self.min_pad = min_pad @@ -1652,11 +1652,17 @@ def upsample_weight(self, N: int) -> Tensor: # a complex weight of shape (num_channels, N). weight = self.weight_proj(self.weight) - num_channels = weight.shape[0] // 2 + # weight: (num_channels, 2 * params_per_channel) + num_channels = weight.shape[0] + params_per_channel = weight.shape[1] // 2 + weight = weight.reshape(num_channels, 2, params_per_channel) + # the following may not be ideal, we'll see. + # in the following, num_channels is interpreted by upsample() as batch and 2 as channels but this + # does not matter as they are treated the same by upsample() weight = torch.nn.functional.upsample(weight, N, mode='linear', align_corners=True) - weight = torch.view_as_complex(weight.permute(1, 2, 0).contiguous()) + weight = torch.view_as_complex(weight.permute(0, 2, 1).contiguous()) # weight: (num_channels, N) return weight From 06075fcf3be53bd2b14b33404933f8082e304b49 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Nov 2025 17:38:58 +0800 Subject: [PATCH 0691/1191] Bug fix --- egs/librispeech/ASR/zipformer/zipformer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index af787f3e74..9447718869 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1687,8 +1687,6 @@ def __init__( ) -> None: """Construct a ConvolutionModule object.""" super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 bottleneck_dim = channels self.causal = causal @@ -1707,7 +1705,6 @@ def __init__( self.activation2 = Identity() # for diagnostics - assert kernel_size % 2 == 1 self.fft_conv1 = FftModule(num_channels=bottleneck_dim, params_per_channel=kernel_size, From 9bc3133c9f6079b513b10524856346ce118f947a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 9 Nov 2025 18:16:17 +0800 Subject: [PATCH 0692/1191] Change weight initialization, removing symmetry. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9447718869..7a19108924 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1611,7 +1611,7 @@ def __init__(self, min_pad: int = 32): super().__init__() # initialize to identity function. - self.weight = nn.Parameter(torch.ones(num_channels, params_per_channel)) + self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) # the factor of 2 is for (sin, cos)a. self.weight_proj = nn.Linear(params_per_channel, 2 * params_per_channel) From f98a3630a7dbb70f89e70f8c3e21de41f41d1fad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Nov 2025 10:31:47 +0800 Subject: [PATCH 0693/1191] Remove second fft_conv and middle activation --- egs/librispeech/ASR/zipformer/zipformer.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d4c6df8bca..9536210dcd 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1460,7 +1460,7 @@ def __init__( bias=True, out_groups=num_heads) self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 ) f = max(1.0, embed_dim / (num_heads * value_head_dim)) @@ -1700,19 +1700,10 @@ def __init__( self.activation2 = Identity() # for diagnostics - self.fft_conv1 = FftModule(num_channels=bottleneck_dim, + self.fft_conv = FftModule(num_channels=bottleneck_dim, params_per_channel=kernel_size, min_pad=32) - self.fft_conv2 = FftModule(num_channels=bottleneck_dim, - params_per_channel=kernel_size, - min_pad=32) - - self.activation3 = SwashR() - - self.middle_bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) - self.out_bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) - self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, @@ -1753,11 +1744,7 @@ def forward( x = self.activation2(x) # identity #x: (time, batch, channels) - x = self.fft_conv1(x) - - x = self.activation3(x) - - x = self.fft_conv2(x) + x = self.fft_conv(x) x = self.out_proj(x) # (time, batch, channels) From 396c19b7fa886d0de1b725f31a3b43530dbb3354 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Nov 2025 11:13:04 +0800 Subject: [PATCH 0694/1191] Reduce num-heads from 4 to 3 and increase value-head-dim from 20 to 32. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 5f95b7aaef..ad5574d2a7 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -220,7 +220,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-heads", type=str, - default="4,4,4,4", + default="3", help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) @@ -241,7 +241,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--value-head-dim", type=str, - default="20", + default="32", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) From cfaaf330f21a96621ba02934cd209fa4ab7bedf8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Nov 2025 11:44:52 +0800 Subject: [PATCH 0695/1191] Introduce shared projection of params of conv modules, stored at stack level. --- egs/librispeech/ASR/zapformer/train.py | 7 +- egs/librispeech/ASR/zipformer/zipformer.py | 137 ++++++++------------- 2 files changed, 55 insertions(+), 89 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ad5574d2a7..946487c569 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -260,11 +260,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( - "--cnn-module-kernel", + "--conv-params", type=str, default="32", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", + help="Parameters per channel of convolution kernels", ) parser.add_argument( @@ -720,7 +719,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: pos_dim=params.pos_dim, num_heads=lookup(params, "num_heads"), feedforward_multiple=lookup(params, "feedforward_multiple"), - cnn_module_kernel=lookup(params, "cnn_module_kernel"), + conv_params=lookup(params, "conv_params"), causal=params.causal, chunk_size=lookup(params, "chunk_size"), left_context_frames=lookup(params, "left_context_frames"), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9536210dcd..8e75e96542 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -78,7 +78,7 @@ class Zipformer2(EncoderInterface): num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. Must be at least 4. feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + conv_params (int or Tuple[int])): Kernel size of convolution module pos_dim (int): the dimension of each positional-encoding vector prior to projection, e.g. 128. @@ -92,8 +92,7 @@ class Zipformer2(EncoderInterface): the chunk size will be randomly chosen from this list. -1 means no chunking. left_context_frames: (list of int): determines the number of left- context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. + chunks. """ def __init__( self, @@ -107,7 +106,7 @@ def __init__( value_head_dim: Union[int, Tuple[int]] = 12, num_heads: Union[int, Tuple[int]] = 8, feedforward_multiple: Union[int, Tuple[int]] = 4, - cnn_module_kernel: Union[int, Tuple[int]] = 31, + conv_params: Union[int, Tuple[int]] = 31, pos_dim: int = 192, causal: bool = False, chunk_size: Tuple[int] = [-1], @@ -137,7 +136,7 @@ def _to_tuple(x): pos_head_dim = _to_tuple(pos_head_dim) self.num_heads = num_heads = _to_tuple(num_heads) feedforward_multiple = _to_tuple(feedforward_multiple) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + self.conv_params = conv_params = _to_tuple(conv_params) self.causal = causal self.chunk_size = chunk_size @@ -164,7 +163,7 @@ def _to_tuple(x): pos_head_dim=pos_head_dim[i], value_head_dim=value_head_dim[i], feedforward_multiple=feedforward_multiple[i], - cnn_module_kernel=cnn_module_kernel[i], + conv_params=conv_params[i], causal=causal, ) @@ -308,7 +307,7 @@ def _get_attn_mask( num_encoders = len(self.encoder_dim) assert all( chunk_size * left_context_chunks - >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + >= (self.conv_params[i] // 2) * self.downsampling_factor[i] for i in range(num_encoders) ) else: @@ -411,7 +410,7 @@ def get_init_states( value_dim = self.value_head_dim[i] * num_heads downsample_left = self.left_context_frames[0] // ds nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 + conv_left_pad = self.cnn_module_kernel[i] // 2 # will be error. have to figure this out. for layer in range(num_layers): cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( device @@ -503,7 +502,7 @@ class Zipformer2EncoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (required). feedforward_multiple: determines the hidden dimension of the feedforward module - cnn_module_kernel (int): Kernel size of convolution module (default=31). + conv_params (int): params per channel of convolution module Examples:: >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) @@ -520,7 +519,7 @@ def __init__( pos_head_dim: int, value_head_dim: int, feedforward_multiple: int, - cnn_module_kernel: int = 31, + conv_params: int, causal: bool = False, ) -> None: super(Zipformer2EncoderLayer, self).__init__() @@ -548,7 +547,7 @@ def __init__( self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim) - self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal) self.norm = ExpNorm(embed_dim) @@ -556,6 +555,7 @@ def __init__( def forward( self, src: Tensor, + weight_proj: Tensor, pos_emb: Tensor, chunk_size: int = -1, attn_mask: Optional[Tensor] = None, @@ -566,6 +566,7 @@ def forward( Pass the input through the encoder layer. Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + weight_proj: to be passed to the convolution modules, of shape (max_conv_length, conv_params) pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), @@ -596,7 +597,7 @@ def forward( src = src + self.self_attn(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + src = src + self.conv_module(src, weight_proj, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -771,6 +772,11 @@ def __init__( (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], dim=0)) + + conv_params = encoder_layer.conv_module.depthwise_conv.weight.shape[1] + max_conv_length = 255 + self.weight_proj = nn.Parameter((max_conv_length ** -0.5) * torch.randn(max_conv_length, conv_params)) + self.copy_bypass = Identity() @@ -813,10 +819,12 @@ def forward( min=-1.0, max=-0.5) src_with_bypass = residual_scale * src + weight_proj = self.weight_proj for i, mod in enumerate(self.layers): src = mod( src, + weight_proj, pos_emb, chunk_size=chunk_size, attn_mask=attn_mask, @@ -1604,63 +1612,42 @@ def round_up_to_power_of_two(x): return x -class FftModule(nn.Module): +class ProjDepthwiseConv(nn.Module): def __init__(self, num_channels: int, params_per_channel: int, - min_pad: int = 32): + bias: bool = True): super().__init__() # initialize to identity function. - self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) - - # the factor of 2 is for (sin, cos)a. - self.weight_proj = nn.Linear(params_per_channel, 2 * params_per_channel) - self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) - # self.weight: (2, num_channels, params_per_channel)( - self.min_pad = min_pad + self.weight = nn.Parameter((params_per_channel ** -0.5) * torch.randn(num_channels, params_per_channel)) + if bias: + self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) + else: + self.bias = None def forward(self, - x: Tensor) -> Tensor: + x: Tensor, + weight_proj: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape - - n = round_up_to_power_of_two(seq_len + self.min_pad) - x = torch.fft.rfft(x, n=n, dim=0, norm="ortho") - - N = x.shape[0] # N == n/2 + 1, the number of fourier components. - # x: (N, batch_size, num_channels) - - weight = self.upsample_weight(N) - eps = 1.0e-05 - # weight: (num_channels, N) - - x = x * weight.t().unsqueeze(1) - - x = torch.fft.irfft(x, dim=0, norm="ortho") - - return x[:seq_len] + self.bias - - - def upsample_weight(self, N: int) -> Tensor: - # N is the desired number of frequencies of weight, so we return - # a complex weight of shape (num_channels, N). - - weight = self.weight_proj(self.weight) - # weight: (num_channels, 2 * params_per_channel) - num_channels = weight.shape[0] - params_per_channel = weight.shape[1] // 2 - weight = weight.reshape(num_channels, 2, params_per_channel) - - # the following may not be ideal, we'll see. - # in the following, num_channels is interpreted by upsample() as batch and 2 as channels but this - # does not matter as they are treated the same by upsample() - weight = torch.nn.functional.upsample(weight, N, mode='linear', align_corners=True) - - weight = torch.view_as_complex(weight.permute(0, 2, 1).contiguous()) - # weight: (num_channels, N) - return weight - - + (_num_channels, params_per_channel) = self.weight.shape + assert weight_proj.shape[1] == params_per_channel + max_conv_length = weight_proj.shape[0] + assert max_conv_length % 2 == 1 + + # if convolution length is longer than seq_len, we can truncate the convolution. + truncate = max(max_conv_length - (seq_len - 1), 0) // 2 + if truncate > 0: + weight_proj = weight_proj[truncate:-truncate] + + weight = torch.matmul(self.weight, weight_proj.t()) + # weight: (num_channels, conv_width); conv_width is odd. + + x = x.permute(1, 2, 0) # (batch, channels, width) + weight = weight.unsqueeze(1) # (num_channels, 1, conv_width) + x = torch.nn.functional.conv1d(x, weight, self.bias, groups=num_channels, padding='same') + x = x.permute(2, 0, 1) # (seq, batch, channels) + return x class ConvolutionModule(nn.Module): @@ -1700,9 +1687,8 @@ def __init__( self.activation2 = Identity() # for diagnostics - self.fft_conv = FftModule(num_channels=bottleneck_dim, - params_per_channel=kernel_size, - min_pad=32) + self.depthwise_conv = ProjDepthwiseConv(bottleneck_dim, + kernel_size) self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, @@ -1715,6 +1701,7 @@ def __init__( def forward( self, x: Tensor, + weight_proj: Tensor, src_key_padding_mask: Optional[Tensor] = None, chunk_size: int = -1, aux_loss_scale: float = 0.0, @@ -1723,6 +1710,7 @@ def forward( Args: x: Input tensor (#time, batch, channels). + weight_proj: tensor of shape (max_conv_length, kernel_size), with max_conv_length > kernel_size; expands the size of the convolution. src_key_padding_mask: the mask for the src keys per batch (optional): (batch, #time), contains True in masked positions. @@ -1730,8 +1718,6 @@ def forward( Tensor: Output tensor (#time, batch, channels). """ - - if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) @@ -1744,7 +1730,7 @@ def forward( x = self.activation2(x) # identity #x: (time, batch, channels) - x = self.fft_conv(x) + x = self.depthwise_conv(x, weight_proj) x = self.out_proj(x) # (time, batch, channels) @@ -1833,24 +1819,6 @@ def _test_zipformer_main(causal: bool = False): ) x_ # to remove flake8 warnings -def _test_fft_module(): - num_channels = 110 - f = FftModule(num_channels=num_channels, - params_per_channel=10, - min_pad=4) - - batch_size = 5 - seq_len = 50 - x = torch.randn(seq_len, batch_size, num_channels) - - y = f(x) - - def rms(a): - return (a**2).mean().item() - - print(f"rms(y)={rms(y)}, rms(x-y)={rms(x-y)}") - - if __name__ == "__main__": @@ -1859,4 +1827,3 @@ def rms(a): torch.set_num_interop_threads(1) _test_zipformer_main(False) _test_zipformer_main(True) - _test_fft_module() From 405b6b49db734e75acb611b88894883af04173d0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Nov 2025 13:34:01 +0800 Subject: [PATCH 0696/1191] Tighten residual_scale limits from 0.1..1 to 0.25..0.75; make stack-level residual_scales scalars. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8e75e96542..0afb1de289 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -601,7 +601,7 @@ def forward( src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) + residual_scale = limit_param_value(self.residual_scale, min=0.25, max=0.75) offset = (src - src_orig) * residual_scale offset = self.offset_scale_limiter(offset, aux_loss_scale) @@ -768,8 +768,8 @@ def __init__( self.num_layers = num_layers self.residual_scales = nn.Parameter( - torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), - (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], + torch.cat([ -1.0 * torch.ones(1), + (1. / num_layers) * torch.ones(num_layers) ], dim=0)) From 836b7e6b18e562db7a3ef220b8a9b3dafd79972c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Nov 2025 17:10:25 +0800 Subject: [PATCH 0697/1191] Fix regardin truncation of convolution. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0afb1de289..d51f36a67e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1636,7 +1636,7 @@ def forward(self, assert max_conv_length % 2 == 1 # if convolution length is longer than seq_len, we can truncate the convolution. - truncate = max(max_conv_length - (seq_len - 1), 0) // 2 + truncate = max(max_conv_length // 2 - (seq_len - 1), 0) if truncate > 0: weight_proj = weight_proj[truncate:-truncate] From 12d96bdf8830321722291afd267b90c2812189d9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 13 Nov 2025 16:16:41 +0800 Subject: [PATCH 0698/1191] Implement circular summation of weights and circular padding mode. --- egs/librispeech/ASR/zipformer/zipformer.py | 51 ++++++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d51f36a67e..fd660b42b6 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1631,21 +1631,54 @@ def forward(self, weight_proj: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape (_num_channels, params_per_channel) = self.weight.shape - assert weight_proj.shape[1] == params_per_channel - max_conv_length = weight_proj.shape[0] - assert max_conv_length % 2 == 1 - # if convolution length is longer than seq_len, we can truncate the convolution. - truncate = max(max_conv_length // 2 - (seq_len - 1), 0) - if truncate > 0: - weight_proj = weight_proj[truncate:-truncate] + weight_proj = weight_proj.t() + # weight_proj: (params_per_channel, conv_length) + assert weight_proj.shape[0] == params_per_channel + conv_length = weight_proj.shape[1] + assert conv_length % 2 == 1 + + # if convolution length is longer than seq_len, we can truncate the convolution by + # wrapping it around (so it will be the same as if we did the full convolution + # with circular padding) + + if conv_length > seq_len: + wrapped_conv_length = seq_len + + # 'multiple' is the number of 'wraps' we sum over. this must be odd so + # that the original middle ends up in the middle after wrapping. + multiple = (conv_length + seq_len - 1) // seq_len + if multiple % 2 == 0: + multiple = multiple + 1 # need multiple to be odd. + padding = (seq_len * multiple) - conv_length + left_pad = padding // 2 + right_pad = padding - left_pad + weight_proj = torch.nn.functional.pad(weight_proj, (left_pad, right_pad)) + + weight_proj = weight_proj.reshape(params_per_channel, multiple, seq_len).sum(dim=1) + # weight_proj: (num_channels, seq_len) + if seq_len % 2 == 0: + # even-length convolution will cause efficiency problems for conv1d, so we pad + # the convolution with a zero on the left (which would have been the side that + # was made shorter by the uneven padding). The fact that it's zero won't matter + # because we'll just get the value from the wrapped around other side, due to + # circular padding. + weight_proj = torch.cat((torch.zeros(params_per_channel, 1, device=weight_proj.device, dtype=weight_proj.dtype), + weight_proj), dim=1) + conv_length = weight_proj.shape[1] + + + weight = torch.matmul(self.weight, weight_proj) + # weight: (num_channels, conv_length) ; note, conv_length may have been reduced to seq_len + 1 already. + padding = conv_length // 2 # note, conv_length will be odd. - weight = torch.matmul(self.weight, weight_proj.t()) # weight: (num_channels, conv_width); conv_width is odd. x = x.permute(1, 2, 0) # (batch, channels, width) weight = weight.unsqueeze(1) # (num_channels, 1, conv_width) - x = torch.nn.functional.conv1d(x, weight, self.bias, groups=num_channels, padding='same') + + x = torch.nn.functional.pad(x, (padding, padding), mode='circular') + x = torch.nn.functional.conv1d(x, weight, self.bias, groups=num_channels) x = x.permute(2, 0, 1) # (seq, batch, channels) return x From ca76dfd3298dddc665919062b65b96dff7dd3ebc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 13 Nov 2025 17:23:51 +0800 Subject: [PATCH 0699/1191] Change to fft-based implementation of convolution --- egs/librispeech/ASR/zipformer/zipformer.py | 56 ++++++++++------------ 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fd660b42b6..47303d754c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1642,43 +1642,37 @@ def forward(self, # wrapping it around (so it will be the same as if we did the full convolution # with circular padding) - if conv_length > seq_len: - wrapped_conv_length = seq_len - - # 'multiple' is the number of 'wraps' we sum over. this must be odd so - # that the original middle ends up in the middle after wrapping. - multiple = (conv_length + seq_len - 1) // seq_len - if multiple % 2 == 0: - multiple = multiple + 1 # need multiple to be odd. - padding = (seq_len * multiple) - conv_length - left_pad = padding // 2 - right_pad = padding - left_pad - weight_proj = torch.nn.functional.pad(weight_proj, (left_pad, right_pad)) - - weight_proj = weight_proj.reshape(params_per_channel, multiple, seq_len).sum(dim=1) - # weight_proj: (num_channels, seq_len) - if seq_len % 2 == 0: - # even-length convolution will cause efficiency problems for conv1d, so we pad - # the convolution with a zero on the left (which would have been the side that - # was made shorter by the uneven padding). The fact that it's zero won't matter - # because we'll just get the value from the wrapped around other side, due to - # circular padding. - weight_proj = torch.cat((torch.zeros(params_per_channel, 1, device=weight_proj.device, dtype=weight_proj.dtype), - weight_proj), dim=1) - conv_length = weight_proj.shape[1] + middle = conv_length // 2 + # pad the convolution so that its middle point is positioned at an exact multiple of seq_len, + # which will become position zero after circular summing; and so that the total length is an + # exact multiple of seq_len. + left_pad = (-middle) % seq_len # caution if you translate this into C, this relies on python's definition. + right_pad = (-(conv_length + left_pad)) % seq_len + weight_proj = torch.nn.functional.pad(weight_proj, (left_pad, right_pad)) + + weight_proj = weight_proj.reshape(params_per_channel, -1, seq_len).sum(dim=1) + # weight_proj: (num_channels, seq_len). Central point of conv is positioned + # at position zero. weight = torch.matmul(self.weight, weight_proj) - # weight: (num_channels, conv_length) ; note, conv_length may have been reduced to seq_len + 1 already. - padding = conv_length // 2 # note, conv_length will be odd. + # weight: (num_channels, seq_len). + + + x = x.permute(1, 2, 0) # (batch_size, num_channels, seq_len) + + both = torch.cat((x, weight.unsqueeze(0)), dim=0) + # both: (batch_size + 1, num_channels, seq_len) + + with torch.amp.autocast('cuda', enabled=False): + # do it in float32 because non power of two seq_len is not supported in half precision. + both = torch.fft.rfft(both.to(torch.float32), norm="ortho") - # weight: (num_channels, conv_width); conv_width is odd. + # multiplication in fourier space is the same as (circular) convolution. + x = both[:-1] * both[-1] - x = x.permute(1, 2, 0) # (batch, channels, width) - weight = weight.unsqueeze(1) # (num_channels, 1, conv_width) + x = torch.fft.irfft(x, norm="ortho") - x = torch.nn.functional.pad(x, (padding, padding), mode='circular') - x = torch.nn.functional.conv1d(x, weight, self.bias, groups=num_channels) x = x.permute(2, 0, 1) # (seq, batch, channels) return x From f00f15d6dd79754fee8e26285f7f2015e769c168 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 13 Nov 2025 17:50:29 +0800 Subject: [PATCH 0700/1191] Repeat elements of the main sequence in the padding region instead of padding with zeroes. --- egs/librispeech/ASR/zipformer/zipformer.py | 33 ++++++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 47303d754c..de597ff36c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1745,9 +1745,6 @@ def forward( Tensor: Output tensor (#time, batch, channels). """ - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - x = self.in_proj(x) # (time, batch, 2*channels) x, s = x.chunk(2, dim=2) @@ -1757,12 +1754,42 @@ def forward( x = self.activation2(x) # identity #x: (time, batch, channels) + if src_key_padding_mask is not None: + x = self.repeat_in_padding(x, src_key_padding_mask) + x = self.depthwise_conv(x, weight_proj) x = self.out_proj(x) # (time, batch, channels) return x + + def repeat_in_padding(self, x, mask): + # repeats elements of x in the padding region, circularly as much as possible; + # the discontinuity between the ones that circularly repeat from the end and + # those that circularly repeat from the beginning is in the middle of the padding + # region. + + # x: (seq_len, batch_size, num_channels) + (batch_size, seq_len) = mask.shape + + seq_lengths = (~mask).to(torch.int64).sum(dim=1, keepdim=True) # (batch_size, 1) + pad_len = seq_len - seq_lengths + arange = torch.arange(seq_len, device=mask.device) + + # "mid" gives the index of the midpoint of the padding region after each sequence. + mid = (seq_lengths + seq_len) // 2 # mid: (batch_size, 1) + + src_index = torch.where(arange >= mid, arange - pad_len, arange) % seq_lengths + # src_index: (batch_size, seq_len) + + src_index = src_index.t().unsqueeze(-1).expand_as(x) + # src_index: (seq_len, batch_size, num_channels) + x = torch.gather(x, dim=0, index=src_index) + return x + + + def streaming_forward( self, x: Tensor, From 30b1f72f02c44d3c91bf9eb108c2b80fcd112ba6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 13 Nov 2025 18:10:42 +0800 Subject: [PATCH 0701/1191] pass in n to irfft --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index de597ff36c..9e0992e538 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1671,7 +1671,7 @@ def forward(self, # multiplication in fourier space is the same as (circular) convolution. x = both[:-1] * both[-1] - x = torch.fft.irfft(x, norm="ortho") + x = torch.fft.irfft(x, n=seq_len, norm="ortho") x = x.permute(2, 0, 1) # (seq, batch, channels) return x From 81aeebb24cca5620be4adf88fe6c6aafa14c7386 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 13 Nov 2025 18:41:53 +0800 Subject: [PATCH 0702/1191] Bug fix regardin length --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 47303d754c..818132fb35 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1671,7 +1671,7 @@ def forward(self, # multiplication in fourier space is the same as (circular) convolution. x = both[:-1] * both[-1] - x = torch.fft.irfft(x, norm="ortho") + x = torch.fft.irfft(x, n=seq_len, norm="ortho") x = x.permute(2, 0, 1) # (seq, batch, channels) return x From ee819316cffd7448e75a2a3b4b7d2ad9a5695cb3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 13 Nov 2025 20:26:46 +0800 Subject: [PATCH 0703/1191] Bug fixes and test the fft based convolution code. --- egs/librispeech/ASR/zipformer/zipformer.py | 93 +++++++++++++++++++++- 1 file changed, 89 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 818132fb35..2c8bd5a2be 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1612,6 +1612,21 @@ def round_up_to_power_of_two(x): return x +class ProjDepthwiseConv(nn.Module): + def __init__(self, + num_channels: int, + params_per_channel: int, + bias: bool = True): + super().__init__() + # initialize to identity function. + self.weight = nn.Parameter((params_per_channel ** -0.5) * torch.randn(num_channels, params_per_channel)) + if bias: + self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) + else: + self.bias = None + + + class ProjDepthwiseConv(nn.Module): def __init__(self, num_channels: int, @@ -1629,6 +1644,76 @@ def __init__(self, def forward(self, x: Tensor, weight_proj: Tensor) -> Tensor: + return self.forward_fft(x, weight_proj) + #a = self.forward_fft(x, weight_proj) + #b = self.forward_conv(x, weight_proj) + #diff = a - b + #def rms(x): + # return (x**2).mean().sqrt() + #print(f"size={x.shape}, rms(a)={rms(a)}, rms(b)={rms(b)}, rms(diff)={rms(diff)}, rms(diff-last)={rms(diff[-1])}, rms(diff-first)={rms(diff[0])}") + #return a + + + def forward_conv(self, + x: Tensor, + weight_proj: Tensor) -> Tensor: + (seq_len, batch_size, num_channels) = x.shape + (_num_channels, params_per_channel) = self.weight.shape + + weight_proj = weight_proj.t() + # weight_proj: (params_per_channel, conv_length) + assert weight_proj.shape[0] == params_per_channel + conv_length = weight_proj.shape[1] + assert conv_length % 2 == 1 + + # if convolution length is longer than seq_len, we can truncate the convolution by + # wrapping it around (so it will be the same as if we did the full convolution + # with circular padding) + + if conv_length > seq_len: + wrapped_conv_length = seq_len + + # 'multiple' is the number of 'wraps' we sum over. this must be odd so + # that the original middle ends up in the middle after wrapping. + multiple = (conv_length + seq_len - 1) // seq_len + if multiple % 2 == 0: + multiple = multiple + 1 # need multiple to be odd. + padding = (seq_len * multiple) - conv_length + left_pad = padding // 2 + right_pad = padding - left_pad + weight_proj = torch.nn.functional.pad(weight_proj, (left_pad, right_pad)) + + weight_proj = weight_proj.reshape(params_per_channel, multiple, seq_len).sum(dim=1) + # weight_proj: (num_channels, seq_len) + if seq_len % 2 == 0: + # even-length convolution will cause efficiency problems for conv1d, so we pad + # the convolution with a zero on the left (which would have been the side that + # was made shorter by the uneven padding). The fact that it's zero won't matter + # because we'll just get the value from the wrapped around other side, due to + # circular padding. + weight_proj = torch.cat((torch.zeros(params_per_channel, 1, device=weight_proj.device, dtype=weight_proj.dtype), + weight_proj), dim=1) + conv_length = weight_proj.shape[1] + + + weight = torch.matmul(self.weight, weight_proj) + # weight: (num_channels, conv_length) ; note, conv_length may have been reduced to seq_len + 1 already. + padding = conv_length // 2 # note, conv_length will be odd. + + # weight: (num_channels, conv_width); conv_width is odd. + + x = x.permute(1, 2, 0) # (batch, channels, width) + weight = weight.unsqueeze(1) # (num_channels, 1, conv_width) + + x = torch.nn.functional.pad(x, (padding, padding), mode='circular') + x = torch.nn.functional.conv1d(x, weight, self.bias, groups=num_channels) + x = x.permute(2, 0, 1) # (seq, batch, channels) + return x + + + def forward_fft(self, + x: Tensor, + weight_proj: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape (_num_channels, params_per_channel) = self.weight.shape @@ -1666,14 +1751,14 @@ def forward(self, with torch.amp.autocast('cuda', enabled=False): # do it in float32 because non power of two seq_len is not supported in half precision. - both = torch.fft.rfft(both.to(torch.float32), norm="ortho") + both = torch.fft.rfft(both.to(torch.float32)) # multiplication in fourier space is the same as (circular) convolution. - x = both[:-1] * both[-1] + x = both[:-1] * both[-1].conj() - x = torch.fft.irfft(x, n=seq_len, norm="ortho") + x = torch.fft.irfft(x, n=seq_len) - x = x.permute(2, 0, 1) # (seq, batch, channels) + x = x.permute(2, 0, 1) + self.bias # (seq, batch, channels) return x From 35159cb476fc4de4feab01e67cbf7dbf171a785e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Nov 2025 14:56:32 +0800 Subject: [PATCH 0704/1191] Change optim.py to decay larger singular values more strongly. --- egs/librispeech/ASR/zipformer/optim.py | 79 +++++++++++++++++++++----- 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 008c176f41..1eb3b16e23 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -150,7 +150,62 @@ def basic_step(group, p, state, grad): -def momentum_step(group, p, state, grad): +def scale_tensor_by(x, beta1): + # internal function called by scale_by; works on "properly shaped" x. is similar in efffect + # to x.mul_(beta1) but decays the larger singular values of the matrices in x more than the smaller + # ones. + if x.ndim > 3: + # each tensor in the batch has more than two dimensions. + # reshape to be like a batch of matrices. + # note: x.shape[0] is batch dimension. + if x.shape[1] > x.shape[-1]: + xr = x.reshape(x.shape[0], x.shape[1], -1) + else: + xr = x.reshape(x.shape[0], -1, x.shape[-1]) + scale_tensor_by(xr, beta1) + x[:] = xr.reshape(*x.shape) + return + if x.shape[1] > x.shape[2]: + xr = x.permute(0, 2, 1) + scale_tensor_by(xr, beta1) + x[:] = xr.permute(0, 2, 1) + return + (batch_size, rows, cols) = x.shape # and rows <= cols + + x2 = torch.matmul(x, x.permute(0, 2, 1)) + # x2: (batch_size, rows, rows) + eps = 1.0e-10 + (batch_stride, stride1, stride2) = x2.stride() + x2_diag_sum = torch.as_strided(x2, (batch_size, rows), (batch_stride, stride1 + stride2)).sum() # (batch_size,) + x2_sq_sum = (x2 ** 2).sum(dim=(1, 2)) # (batch_size,) + scale = x2_diag_sum / x2_sq_sum + + x_scaled = torch.matmul(x2, x) * scale[:, None, None] + #if True: + # dot_prod1 = (x * x).sum(dim=(1, 2)) + # dot_prod2 = (x * x_scaled).sum(dim=(1, 2)) + # these dot products are the same when printed, as intended. + # print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") + + x.add_(x_scaled, alpha=(beta1-1)) # note: negative alpha. + + + +def scale_by(x, beta1, shape): + # x is a tensor of shape (batch_size, per_tensor_numel + 1), + # where the + 1 is for the log scale. + # 'shape' is the shape of x before we flattened it. + # if x represents a bias or a scalar, just do x.mul_(beta1). + # note: the first dim of x is a "batch dim" which is a batch of same-shaped tensors. + if len(shape) <= 2: + x.mul_(beta1) + return + + scale_tensor_by(x[:, :-1].reshape(*shape), beta1) + x[:, -1].mul_(beta1) + + +def momentum_step(group, p, state, grad, shape): delta = basic_step(group, p, state, grad) #beta1 = group["betas"][0] @@ -168,7 +223,9 @@ def momentum_step(group, p, state, grad): stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) state["delta"] = stored_delta - stored_delta.mul_(beta1).add_(delta) + + stored_delta.add_(delta) + scale_by(stored_delta, beta1, shape) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) @@ -193,14 +250,13 @@ def forward_transform_param(group, p): abs_sum = p_flat.abs().sum(dim=1, keepdim=True) min_abs_sum = min_scale * numel # if abs_sum is less than this we pad with an extra element. abs_sum_clamped = abs_sum.clamp(min=min_abs_sum) - pad = (abs_sum_clamped - abs_sum) scale = (abs_sum_clamped / numel) # must be nonzero thanks to min_abs_sum # scaling_lr_scale is to control the learning-rate of scaling factors. # log_scale controls the overall scale of this tensor log_scale = (1 / group["scaling_lr_scale"]) * scale.log() - ans = torch.cat((p_flat / scale, pad / scale, log_scale), dim=1) + ans = torch.cat((p_flat / scale, log_scale), dim=1) return ans def reverse_transform_param(group, p, orig_shape): @@ -208,13 +264,12 @@ def reverse_transform_param(group, p, orig_shape): if p.numel() == batch_size: return (p * group["scalar_lr_scale"]).reshape(*orig_shape) # numel is num elements of each parameter tensor in the batch. - numel = p.shape[1] - 2 - p_padded = p[:, :numel+1] # orig tensor plus one padding element + numel = p.shape[1] - 1 is_weight = (len(orig_shape) > 2) max_scale = group["weight_max_scale"] if is_weight else group["bias_max_scale"] min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] - log_scale = (p[:, numel+1:numel+2] * group["scaling_lr_scale"]) + log_scale = (p[:, numel:] * group["scaling_lr_scale"]) scaling_lr = group["scaling_lr_scale"] * group["lr"] @@ -224,7 +279,7 @@ def reverse_transform_param(group, p, orig_shape): log_scale = ((log_scale - log_scale_default) * (1. - group["scale_decay"] * scaling_lr)) + log_scale_default scale = log_scale.exp().clamp(min=min_scale, max=max_scale) - q = p_padded[:, :-1] * scale # the :-1 is to remove the padding element. + q = p[:, :-1] * scale q = q.reshape(*orig_shape) return q @@ -245,7 +300,7 @@ def scaling_step(group, p, state, grad): p_shape = p.shape p_flat, grad_flat = forward_transform_param_and_grad(group, p, grad) - p_flat += momentum_step(group, p_flat, state, grad_flat) + p_flat += momentum_step(group, p_flat, state, grad_flat, p_shape) p = reverse_transform_param(group, p_flat, p.shape) return p @@ -890,9 +945,6 @@ def __init__( ) super().__init__(params, defaults) - self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook) - - def __setstate__(self, state): super(TransformedAdam, self).__setstate__(state) @@ -1258,8 +1310,7 @@ def _test_sched3(): m = torch.nn.Linear(100, 100) optim = TransformedAdam(m.parameters(), lr=0.03) - scheduler = Sched3(optim, lr_batches=100, p=0.8, verbose=True, warmup_batches=20) - + scheduler = Sched3(optim, lr_batches=100, power=0.5, verbose=True, warmup_batches=20) for step in range(300): x = torch.randn(200, 100).detach() From 42756e82e3bc3181ab431bd950e14a04d9b60852 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Nov 2025 15:11:27 +0800 Subject: [PATCH 0705/1191] Decrease warmup_start --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 946487c569..b1ef485a63 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1402,7 +1402,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), warmup_start=0.25, power=0.5) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 07d2463c59845422d158a38b1bfa28d98b6fa6de Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Nov 2025 15:36:05 +0800 Subject: [PATCH 0706/1191] Fix bug regarding batch sum; revert previous change to warmup_start. --- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/optim.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b1ef485a63..946487c569 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1402,7 +1402,7 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), warmup_start=0.25, power=0.5) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 1eb3b16e23..c544caec4b 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -176,15 +176,19 @@ def scale_tensor_by(x, beta1): # x2: (batch_size, rows, rows) eps = 1.0e-10 (batch_stride, stride1, stride2) = x2.stride() - x2_diag_sum = torch.as_strided(x2, (batch_size, rows), (batch_stride, stride1 + stride2)).sum() # (batch_size,) + # x_squared_sum, equivalent to (x**2).sum(dim=(1, 2)), but faster to compute. + x2_diag_sum = torch.as_strided(x2, (batch_size, rows), (batch_stride, stride1 + stride2)).sum(dim=1) # (batch_size,) + x2_sq_sum = (x2 ** 2).sum(dim=(1, 2)) # (batch_size,) scale = x2_diag_sum / x2_sq_sum x_scaled = torch.matmul(x2, x) * scale[:, None, None] + + #x_scaled_squared_sum = (x ** 2).sum(dim=(1, 2 + #if True: # dot_prod1 = (x * x).sum(dim=(1, 2)) # dot_prod2 = (x * x_scaled).sum(dim=(1, 2)) - # these dot products are the same when printed, as intended. # print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") x.add_(x_scaled, alpha=(beta1-1)) # note: negative alpha. From 51c36290c91cc1262e19a033e9e357bb34d6f2b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Nov 2025 16:43:22 +0800 Subject: [PATCH 0707/1191] Mave max dimension of 1024 for matrix multiplication in optimizer. --- egs/librispeech/ASR/zipformer/optim.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c544caec4b..9d48e2ec75 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -170,6 +170,20 @@ def scale_tensor_by(x, beta1): scale_tensor_by(xr, beta1) x[:] = xr.permute(0, 2, 1) return + + # avoid matrix multiplies by any dimensions that are too large. + max_dim = 1024 + if x.shape[1] > max_dim: + n = x.shape[1] + for divisor in range(2, 100): + if n % divisor == 0 and n // divisor <= max_dim: + xr = x.reshape(x.shape[0], n // divisor, divisor * x.shape[2]) + scale_tensor_by(xr, beta1) + x[:] = xr.reshape(*x.shape) + return + # if no divisor worked, just continue. + + (batch_size, rows, cols) = x.shape # and rows <= cols x2 = torch.matmul(x, x.permute(0, 2, 1)) @@ -190,11 +204,9 @@ def scale_tensor_by(x, beta1): # dot_prod1 = (x * x).sum(dim=(1, 2)) # dot_prod2 = (x * x_scaled).sum(dim=(1, 2)) # print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") - x.add_(x_scaled, alpha=(beta1-1)) # note: negative alpha. - def scale_by(x, beta1, shape): # x is a tensor of shape (batch_size, per_tensor_numel + 1), # where the + 1 is for the log scale. From 484defc06d1532fc35f653b331da79586eec3c3b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Nov 2025 21:31:53 +0800 Subject: [PATCH 0708/1191] Anti-interpolate with baseline form of decay of delta, coeff=-0.5. --- egs/librispeech/ASR/zipformer/optim.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9d48e2ec75..15c2331cab 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -204,7 +204,13 @@ def scale_tensor_by(x, beta1): # dot_prod1 = (x * x).sum(dim=(1, 2)) # dot_prod2 = (x * x_scaled).sum(dim=(1, 2)) # print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") - x.add_(x_scaled, alpha=(beta1-1)) # note: negative alpha. + + # interpolate with the basic form of decay as a compromise. + # a negative interpolation coefficient was more promising in a test, trying that first. + baseline_coeff = -0.5 + x3_coeff = 1. - baseline_coeff + x.mul_(baseline_coeff * beta1 + x3_coeff) + x.add_(x_scaled, alpha=x3_coeff * (beta1-1)) # note: negative alpha. def scale_by(x, beta1, shape): From aea81b3e39fafe98f4c2531f6b5f6c917c67c8e1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Nov 2025 22:06:03 +0800 Subject: [PATCH 0709/1191] After not promising results, change baseline_coeff from -0.5 to 0.25. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 15c2331cab..14fb7757d4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -207,7 +207,7 @@ def scale_tensor_by(x, beta1): # interpolate with the basic form of decay as a compromise. # a negative interpolation coefficient was more promising in a test, trying that first. - baseline_coeff = -0.5 + baseline_coeff = 0.25 x3_coeff = 1. - baseline_coeff x.mul_(baseline_coeff * beta1 + x3_coeff) x.add_(x_scaled, alpha=x3_coeff * (beta1-1)) # note: negative alpha. From a2a85e2edb1014daa3b7e744805dfe0f48a22311 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Nov 2025 14:24:15 +0800 Subject: [PATCH 0710/1191] For tensors more than 1024, use a better method of reshaping (use the batch index); implement interpolation-with-baseline via the step count with higher coeff (0.25->0.33). --- egs/librispeech/ASR/zipformer/optim.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 14fb7757d4..f35fe90732 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -163,12 +163,14 @@ def scale_tensor_by(x, beta1): else: xr = x.reshape(x.shape[0], -1, x.shape[-1]) scale_tensor_by(xr, beta1) - x[:] = xr.reshape(*x.shape) + if not xr.storage() is x.storage(): + x[:] = xr.reshape(*x.shape) return if x.shape[1] > x.shape[2]: xr = x.permute(0, 2, 1) scale_tensor_by(xr, beta1) - x[:] = xr.permute(0, 2, 1) + if not xr.storage() is x.storage(): + x[:] = xr.permute(0, 2, 1) return # avoid matrix multiplies by any dimensions that are too large. @@ -177,9 +179,10 @@ def scale_tensor_by(x, beta1): n = x.shape[1] for divisor in range(2, 100): if n % divisor == 0 and n // divisor <= max_dim: - xr = x.reshape(x.shape[0], n // divisor, divisor * x.shape[2]) + xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) scale_tensor_by(xr, beta1) - x[:] = xr.reshape(*x.shape) + if not xr.storage() is x.storage(): + x[:] = xr.reshape(*x.shape) return # if no divisor worked, just continue. @@ -205,12 +208,7 @@ def scale_tensor_by(x, beta1): # dot_prod2 = (x * x_scaled).sum(dim=(1, 2)) # print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") - # interpolate with the basic form of decay as a compromise. - # a negative interpolation coefficient was more promising in a test, trying that first. - baseline_coeff = 0.25 - x3_coeff = 1. - baseline_coeff - x.mul_(baseline_coeff * beta1 + x3_coeff) - x.add_(x_scaled, alpha=x3_coeff * (beta1-1)) # note: negative alpha. + x.add_(x_scaled, alpha=(beta1-1)) # note: negative alpha. def scale_by(x, beta1, shape): @@ -247,7 +245,12 @@ def momentum_step(group, p, state, grad, shape): stored_delta.add_(delta) - scale_by(stored_delta, beta1, shape) + if step % 3 == 0: + # every third step, just do a normal decay, this is an efficient way of + # doing a kind of interpolation with the fourth-power regularization. + stored_delta.mul_(beta1) + else: + scale_by(stored_delta, beta1, shape) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) From e94a5cee73f5138adfbaefb9ceacb3db0c1e55fe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Nov 2025 15:06:27 +0800 Subject: [PATCH 0711/1191] Introduce scale of 1./(2. + position.abs()) into the weight_proj for the convolutions. --- egs/librispeech/ASR/zipformer/zipformer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ce01f72db8..aac061d1e7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -775,7 +775,10 @@ def __init__( conv_params = encoder_layer.conv_module.depthwise_conv.weight.shape[1] max_conv_length = 255 - self.weight_proj = nn.Parameter((max_conv_length ** -0.5) * torch.randn(max_conv_length, conv_params)) + self.weight_proj = nn.Parameter(torch.randn(max_conv_length, conv_params)) + # scale weight_proj with a scale that's smaller for 'further-away-from-the-center' positions, since these positions + # will tend to have smaller weights. + self.register_buffer('weight_proj_scale', 1. / (2. + (torch.arange(conv_params) - (conv_params // 2)).abs())) self.copy_bypass = Identity() @@ -819,7 +822,7 @@ def forward( min=-1.0, max=-0.5) src_with_bypass = residual_scale * src - weight_proj = self.weight_proj + weight_proj = self.weight_proj * self.weight_proj_scale for i, mod in enumerate(self.layers): src = mod( From d6a6b1a5ca247a3fa2855bfea29a07dc41871c4b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Nov 2025 15:35:07 +0800 Subject: [PATCH 0712/1191] Bug fix --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index aac061d1e7..4d3dba83a0 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -778,7 +778,7 @@ def __init__( self.weight_proj = nn.Parameter(torch.randn(max_conv_length, conv_params)) # scale weight_proj with a scale that's smaller for 'further-away-from-the-center' positions, since these positions # will tend to have smaller weights. - self.register_buffer('weight_proj_scale', 1. / (2. + (torch.arange(conv_params) - (conv_params // 2)).abs())) + self.register_buffer('weight_proj_scale', (1. / (2. + (torch.arange(max_conv_length) - (max_conv_length // 2)).abs())).unsqueeze(-1)) self.copy_bypass = Identity() From b2f8f9c66837a82f60451d481349849695b65261 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Nov 2025 16:02:21 +0800 Subject: [PATCH 0713/1191] Increase beta1 from 0.995 to 0.998. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index f35fe90732..6683b1c1f6 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -453,7 +453,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.995, + beta1=0.998, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, From 50e5911719d4946678b13910256efd5298f5ebca Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 16 Nov 2025 21:00:05 +0800 Subject: [PATCH 0714/1191] Increase beta1 from 0.998 to 0.999; make beta1 warmup slower, changing factor of 0.25 to 0.2 --- egs/librispeech/ASR/zipformer/optim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6683b1c1f6..3396a813b3 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -232,7 +232,7 @@ def momentum_step(group, p, state, grad, shape): lr = group["lr"] step = state["step"] - beta1 = min(group["beta1"], 1. - 1. / (10. + 0.25 * step)) + beta1 = min(group["beta1"], 1. - 1. / (10. + 0.2 * step)) direct = group["direct"] try: @@ -453,7 +453,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.998, + beta1=0.999, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, @@ -936,7 +936,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.995, + beta1=0.999, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, From d977ba67bc266545f5ad5ac4459f706c05940ebe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 09:58:22 +0800 Subject: [PATCH 0715/1191] Increase beta1 from 0.999 to 0.9995. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3396a813b3..e23b41dfa5 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -453,7 +453,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.999, + beta1=0.9995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, From 3992a0e53caa04c7a1520e92fd2fd5eb2278267f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 12:04:12 +0800 Subject: [PATCH 0716/1191] Increase value-head-dim from 32 to 48. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 946487c569..4c815c874a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -241,7 +241,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--value-head-dim", type=str, - default="32", + default="48", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) From d5f51cbc1790ca1f6abb7dbd1fc72ba1e5396908 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 15:25:13 +0800 Subject: [PATCH 0717/1191] Refactor optimizer, changing scale update and scaling update. --- egs/librispeech/ASR/zipformer/optim.py | 170 ++++++++++--------------- 1 file changed, 68 insertions(+), 102 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index e23b41dfa5..372e4fab52 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -124,7 +124,7 @@ def batched_params(self, param_group, group_params_names): -def basic_step(group, p, state, grad): +def basic_step(group, state, grad): # computes basic Adam normalized-grad using beta2 (dividing by gradient stddev) only. no momentum yet. beta2 = group["beta2"] eps = group["eps"] @@ -133,7 +133,7 @@ def basic_step(group, p, state, grad): exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) or (batch_size, 1, [1,..]) except KeyError: assert state["step"] < 2 - exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) state["exp_avg_sq"] = exp_avg_sq exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -150,10 +150,14 @@ def basic_step(group, p, state, grad): -def scale_tensor_by(x, beta1): - # internal function called by scale_by; works on "properly shaped" x. is similar in efffect +def scale_by(x, beta1): + # This is similar in efffect # to x.mul_(beta1) but decays the larger singular values of the matrices in x more than the smaller # ones. + if x.ndim <= 2: + x.mul_(beta1) + return + if x.ndim > 3: # each tensor in the batch has more than two dimensions. # reshape to be like a batch of matrices. @@ -162,13 +166,13 @@ def scale_tensor_by(x, beta1): xr = x.reshape(x.shape[0], x.shape[1], -1) else: xr = x.reshape(x.shape[0], -1, x.shape[-1]) - scale_tensor_by(xr, beta1) + scale_by(xr, beta1) if not xr.storage() is x.storage(): x[:] = xr.reshape(*x.shape) return if x.shape[1] > x.shape[2]: xr = x.permute(0, 2, 1) - scale_tensor_by(xr, beta1) + scale_by(xr, beta1) if not xr.storage() is x.storage(): x[:] = xr.permute(0, 2, 1) return @@ -180,7 +184,7 @@ def scale_tensor_by(x, beta1): for divisor in range(2, 100): if n % divisor == 0 and n // divisor <= max_dim: xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) - scale_tensor_by(xr, beta1) + scale_by(xr, beta1) if not xr.storage() is x.storage(): x[:] = xr.reshape(*x.shape) return @@ -211,24 +215,8 @@ def scale_tensor_by(x, beta1): x.add_(x_scaled, alpha=(beta1-1)) # note: negative alpha. -def scale_by(x, beta1, shape): - # x is a tensor of shape (batch_size, per_tensor_numel + 1), - # where the + 1 is for the log scale. - # 'shape' is the shape of x before we flattened it. - # if x represents a bias or a scalar, just do x.mul_(beta1). - # note: the first dim of x is a "batch dim" which is a batch of same-shaped tensors. - if len(shape) <= 2: - x.mul_(beta1) - return - - scale_tensor_by(x[:, :-1].reshape(*shape), beta1) - x[:, -1].mul_(beta1) - - -def momentum_step(group, p, state, grad, shape): - delta = basic_step(group, p, state, grad) - - #beta1 = group["betas"][0] +def momentum_step(group, state, grad): + delta = basic_step(group, state, grad) lr = group["lr"] step = state["step"] @@ -240,7 +228,7 @@ def momentum_step(group, p, state, grad, shape): except KeyError as e: assert step < 2 # scalar. use conventional momentum. - stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) state["delta"] = stored_delta @@ -250,85 +238,74 @@ def momentum_step(group, p, state, grad, shape): # doing a kind of interpolation with the fourth-power regularization. stored_delta.mul_(beta1) else: - scale_by(stored_delta, beta1, shape) + scale_by(stored_delta, beta1) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) +def basic_momentum_step(group, state, grad, lr, beta): + delta = basic_step(group, state, grad) -def forward_transform_param(group, p): - """ - Returns a transformed version of the batch of parameters (dimension 0 of p is the batch - of same-shaped parameters). - The transformation is from a parameter to a (parameter-direction, log-weight), where - parameter-direction has unit RMS value and log-weight - """ - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel == 1: - # scalar parameters are treated specially. scalar_lr_scale is to control - # the learning-rate of scalars. - return p.reshape(batch_size, 1) / group["scalar_lr_scale"] + step = state["step"] + try: + stored_delta = state["delta"] + except KeyError as e: + assert step < 2 + # scalar. use conventional momentum. + stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["delta"] = stored_delta - is_weight = (p.ndim > 2) - min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] - p_flat = p.reshape(batch_size, numel) - abs_sum = p_flat.abs().sum(dim=1, keepdim=True) - min_abs_sum = min_scale * numel # if abs_sum is less than this we pad with an extra element. - abs_sum_clamped = abs_sum.clamp(min=min_abs_sum) - scale = (abs_sum_clamped / numel) # must be nonzero thanks to min_abs_sum - - # scaling_lr_scale is to control the learning-rate of scaling factors. - # log_scale controls the overall scale of this tensor - log_scale = (1 / group["scaling_lr_scale"]) * scale.log() - - ans = torch.cat((p_flat / scale, log_scale), dim=1) - return ans - -def reverse_transform_param(group, p, orig_shape): - batch_size = p.shape[0] - if p.numel() == batch_size: - return (p * group["scalar_lr_scale"]).reshape(*orig_shape) - # numel is num elements of each parameter tensor in the batch. - numel = p.shape[1] - 1 - - is_weight = (len(orig_shape) > 2) - max_scale = group["weight_max_scale"] if is_weight else group["bias_max_scale"] + stored_delta.add_(delta) + stored_delta.mul_(beta) + + delta = (-lr * (1 - beta)) * stored_delta + return delta + +def get_scale(group, p, grad): + is_weight = (p.ndim > 2) # is weight, not bias. for scalars, we do not + # reach here. 1st dim is batch-of-params dim. min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] - log_scale = (p[:, numel:] * group["scaling_lr_scale"]) - scaling_lr = group["scaling_lr_scale"] * group["lr"] + dims = tuple(range(1, p.ndim)) + abs_mean = p.abs().mean(dim=dims, keepdim=True) + abs_mean = abs_mean.clamp(min=min_scale) - # Apply weight-decay of log_scale, similar to weight decay of AdamW, except it regresses the - # log-scale to a default value instead of regressing the scale towards zero. - log_scale_default = math.log(group["scale_default"]) - log_scale = ((log_scale - log_scale_default) * (1. - group["scale_decay"] * scaling_lr)) + log_scale_default - scale = log_scale.exp().clamp(min=min_scale, max=max_scale) + scale = abs_mean - q = p[:, :-1] * scale - q = q.reshape(*orig_shape) - return q + log_scale_grad = (p * grad).sum(dim=dims, keepdim=True) + return scale, log_scale_grad -def forward_transform_param_and_grad(group, p, grad): - # returns new parameter. - p_shape = p.shape - p_flat = forward_transform_param(group, p).detach() - with torch.enable_grad(): - p_flat.requires_grad = True - p_reconstruct = reverse_transform_param(group, p_flat, p.shape) - p_reconstruct.backward(gradient=grad) - return p_flat.detach(), p_flat.grad def scaling_step(group, p, state, grad): # returns new parameter. p_shape = p.shape - p_flat, grad_flat = forward_transform_param_and_grad(group, p, grad) - p_flat += momentum_step(group, p_flat, state, grad_flat, p_shape) + scale, log_scale_grad = get_scale(group, p, grad) + + try: + scale_state = state["scale"] + except: + scale_state = dict() + state["scale"] = scale_state + scale_state["step"] = state["step"] + + scale_lr = group["lr"] * group["scalar_lr_scale"] + delta_log_scale = basic_momentum_step(group, scale_state, log_scale_grad, + lr=scale_lr, beta=0.9) + # the following is decay of the log scale towards a user-specified default value, like + # AdamW but on the log of the scale. + delta_log_scale = delta_log_scale - scale_lr * (scale.log() - math.log(group["scale_default"])) * group["scale_decay"] + + is_weight = (p.ndim > 2) + max_scale = group["weight_max_scale"] if is_weight else group["bias_max_scale"] + min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] + new_scale = (scale * (1. + delta_log_scale)).clamp(min=min_scale, max=max_scale) + + delta = momentum_step(group, state, grad) + + return p * (new_scale / scale) + delta * scale - p = reverse_transform_param(group, p_flat, p.shape) - return p def debug_step(group, p, state, grad): @@ -336,7 +313,10 @@ def debug_step(group, p, state, grad): debug_buffer_size = 256 step = state["step"] - p = scaling_step(group, p, state, grad) + if p.shape[0] == p.numel(): + p = p + basic_momentum_step(group, state, grad, lr=group["lr"]*group["scalar_lr_scale"], beta=0.9) + else: + p = scaling_step(group, p, state, grad) if debug_interval == 0 or step % debug_interval != 0: return p @@ -1602,19 +1582,6 @@ def _test_transformed_adam(hidden_dim: int): logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") -def _test_transform_params(): - # caution: this has occasional errors. - group = { "bias_min_scale": 0.001, "weight_min_scale": 0.01, "scalar_lr_scale": 0.1, "scaling_lr_scale": 0.5, - "scale_default": 0.05, "scale_decay": 0.01, - "weight_max_scale": 20.0, "bias_max_scale": 20.0, "lr": 0.0} # lr set to 0.0 so weight-scale decay does not happen. - for scale in [ 0.0, 1.0e-05, 0.001, 0.01, 1.0, 10.0 ]: - for shape in [ (1, 1), (2, 1), (2, 2), (2, 3, 4), (3, 10, 20), (4,) ]: - p = scale * torch.randn(*shape) - q = forward_transform_param(group, p) - r = reverse_transform_param(group, q, p.shape) - assert torch.allclose(p, r, atol=1.0e-02), (p, q, r) - - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) @@ -1632,7 +1599,6 @@ def _test_transform_params(): else: hidden_dim = 200 - _test_transform_params() _test_transformed_adam(hidden_dim) _test_eden() _test_sched3() From 16491637b0b9d5bfd643ee5dd3074f0693fc9e66 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 15:58:38 +0800 Subject: [PATCH 0718/1191] Fix a bug, double scalar_lr_scale and scaling_lr_scale --- egs/librispeech/ASR/zipformer/optim.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 372e4fab52..5d737e8806 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -167,13 +167,13 @@ def scale_by(x, beta1): else: xr = x.reshape(x.shape[0], -1, x.shape[-1]) scale_by(xr, beta1) - if not xr.storage() is x.storage(): + if not xr.untyped_storage() is x.untyped_storage(): x[:] = xr.reshape(*x.shape) return if x.shape[1] > x.shape[2]: xr = x.permute(0, 2, 1) scale_by(xr, beta1) - if not xr.storage() is x.storage(): + if not xr.untyped_storage() is x.untyped_storage(): x[:] = xr.permute(0, 2, 1) return @@ -185,7 +185,7 @@ def scale_by(x, beta1): if n % divisor == 0 and n // divisor <= max_dim: xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) scale_by(xr, beta1) - if not xr.storage() is x.storage(): + if not xr.untyped_storage() is x.untyped_storage(): x[:] = xr.reshape(*x.shape) return # if no divisor worked, just continue. @@ -290,12 +290,12 @@ def scaling_step(group, p, state, grad): state["scale"] = scale_state scale_state["step"] = state["step"] - scale_lr = group["lr"] * group["scalar_lr_scale"] + scale_lr = group["lr"] * group["scaling_lr_scale"] delta_log_scale = basic_momentum_step(group, scale_state, log_scale_grad, lr=scale_lr, beta=0.9) # the following is decay of the log scale towards a user-specified default value, like # AdamW but on the log of the scale. - delta_log_scale = delta_log_scale - scale_lr * (scale.log() - math.log(group["scale_default"])) * group["scale_decay"] + delta_log_scale = delta_log_scale - (scale_lr * group["scale_decay"]) * (scale.log() - math.log(group["scale_default"])) is_weight = (p.ndim > 2) max_scale = group["weight_max_scale"] if is_weight else group["bias_max_scale"] @@ -438,8 +438,8 @@ def __init__( beta2=0.98, scale_decay=0.01, scale_default=0.05, - scalar_lr_scale=0.1, - scaling_lr_scale=0.1, + scalar_lr_scale=0.2, + scaling_lr_scale=0.2, eps=1.0e-08, weight_min_scale=0.005, weight_max_scale=1.0, @@ -800,9 +800,7 @@ def _show_param_with_unusual_grad( for (p, state, batch_param_names) in tuples: dims = list(range(1, p.ndim)) - p_flat, grad_flat = forward_transform_param_and_grad(group, p, p.grad) - - grad_ratio = ((grad_flat ** 2).mean(dim=1) / state["exp_avg_sq"].mean(dim=1)).sqrt() + grad_ratio = ((p.grad ** 2).mean(dim=dims) / state["exp_avg_sq"].mean(dim=dims)).sqrt() ratios_names += zip(grad_ratio.to('cpu').tolist(), batch_param_names) ratios_names = sorted(ratios_names, reverse=True) @@ -922,7 +920,7 @@ def __init__( scale_decay=0.01, scale_default=0.05, scalar_lr_scale=0.1, - scaling_lr_scale=0.1, + scaling_lr_scale=0.2, eps=1.0e-08, weight_min_scale=0.005, weight_max_scale=1.0, From fad93b3b2c686aed03f61642c4802e4246036676 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 17:28:54 +0800 Subject: [PATCH 0719/1191] Version of optimizer where scale on x2 term is clamped to a maximum. --- egs/librispeech/ASR/zipformer/optim.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 5d737e8806..1d78622fb1 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -124,7 +124,7 @@ def batched_params(self, param_group, group_params_names): -def basic_step(group, state, grad): +def base_step(group, state, grad): # computes basic Adam normalized-grad using beta2 (dividing by gradient stddev) only. no momentum yet. beta2 = group["beta2"] eps = group["eps"] @@ -194,29 +194,31 @@ def scale_by(x, beta1): (batch_size, rows, cols) = x.shape # and rows <= cols x2 = torch.matmul(x, x.permute(0, 2, 1)) - # x2: (batch_size, rows, rows) + x3 = torch.matmul(x2, x) + eps = 1.0e-10 (batch_stride, stride1, stride2) = x2.stride() # x_squared_sum, equivalent to (x**2).sum(dim=(1, 2)), but faster to compute. x2_diag_sum = torch.as_strided(x2, (batch_size, rows), (batch_stride, stride1 + stride2)).sum(dim=1) # (batch_size,) - x2_sq_sum = (x2 ** 2).sum(dim=(1, 2)) # (batch_size,) scale = x2_diag_sum / x2_sq_sum - x_scaled = torch.matmul(x2, x) * scale[:, None, None] + alpha = (1. / 6) * (1 - beta1 ** 2) ** 2 + alpha = min(0.01, alpha) - #x_scaled_squared_sum = (x ** 2).sum(dim=(1, 2 + if False: + print(f"alpha={alpha}, scale={scale * (1-beta1)}") + dot_prod1 = (x * x).sum() + dot_prod2 = (x * x3).sum() * alpha + print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") - #if True: - # dot_prod1 = (x * x).sum(dim=(1, 2)) - # dot_prod2 = (x * x_scaled).sum(dim=(1, 2)) - # print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") + x.add_(x3 * (scale * (1-beta1)).clamp(max=alpha)[:, None, None], alpha=-1) - x.add_(x_scaled, alpha=(beta1-1)) # note: negative alpha. def momentum_step(group, state, grad): - delta = basic_step(group, state, grad) + delta = base_step(group, state, grad) + # delta is the normalized gradient; the rms of delta should be around 1. lr = group["lr"] step = state["step"] @@ -243,7 +245,7 @@ def momentum_step(group, state, grad): def basic_momentum_step(group, state, grad, lr, beta): - delta = basic_step(group, state, grad) + delta = base_step(group, state, grad) step = state["step"] try: From 24e7fb22566b4de1fd0ec11a678710302a1a020e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 18:49:56 +0800 Subject: [PATCH 0720/1191] Introduce a more principled way of computing the scaling factor on the x3 term, and the post-scaling. --- egs/librispeech/ASR/zipformer/optim.py | 62 ++++++++++++++++++++------ 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 1d78622fb1..6de0e4f3b1 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -196,15 +196,50 @@ def scale_by(x, beta1): x2 = torch.matmul(x, x.permute(0, 2, 1)) x3 = torch.matmul(x2, x) - eps = 1.0e-10 + # Suppose we set: x' = x - alpha x3 + # what is the alpha that minimizes the variance of the resulting x? (This is relevant + # because even this alpha may not be enough to decrease the variance by a factor of beta1^2.) + # (x - alpha x3)^2 = x^2 - 2 alpha x^4 + alpha^2 x6. + # alpha that minimizes this is x^2 / x^4 + + x6_sum = (x3 ** 2).sum(dim=(1, 2)) # equals numel * mean[x^6] + x4_sum = (x2 ** 2).sum(dim=(1, 2)) # equals numel * E[x^4] (batch_stride, stride1, stride2) = x2.stride() - # x_squared_sum, equivalent to (x**2).sum(dim=(1, 2)), but faster to compute. - x2_diag_sum = torch.as_strided(x2, (batch_size, rows), (batch_stride, stride1 + stride2)).sum(dim=1) # (batch_size,) - x2_sq_sum = (x2 ** 2).sum(dim=(1, 2)) # (batch_size,) - scale = x2_diag_sum / x2_sq_sum + x2_sum = torch.as_strided(x2, (batch_size, rows), (batch_stride, stride1 + stride2)).sum(dim=1) # (batch_size,) - alpha = (1. / 6) * (1 - beta1 ** 2) ** 2 - alpha = min(0.01, alpha) + + + eps = 1.0e-30 + + # we want the orig var (x^2) to be scaled by beta1^2 after the update. x2,x4,x6 below are all sums or + # means: x2_sum, x4_sum, x6_sum in the code. + # beta1^2 x2 = x^2 - 2 alpha x^4 + alpha^2 x^6 + # 0 = (1 - beta1^2) x^2 - 2 alpha x^4 + alpha^2 x^6. + # this is a quadratic equation in alpha: a alpha^2 + b alpha + c = 0, with: + # a = x^6 + # b = -2 x^4 + # c = (1 - beta1^2) x^2 + # and we want the smaller of the two solutions in alpha, which is a more minimal change to the params, less overshoot, so: + # alpha = (-b - sqrt(b^2 - 4ac)) / 2 a + # = (2 x^4 - sqrt( (4 * x^4)^2 - 4 * ((1-beta^2) x^2 x^6)) / (2 x^6.) + # = (x^4 - sqrt( (x^4)^2 - ((1-beta^2) x^2 x^6)) / x^6. + # below, clamping the term before the sqrt means that if the equation is not solvable we'll just + # take the maximum variance reduction we can, given by x4 / x6, and we'll later do conventional + # shrinkage (scaling by a number less than one) to get the required variance reduction. + + beta1_2 = beta1 ** 2 + + alpha = (x4_sum - (x4_sum**2 - (1 - beta1_2) * x2_sum * x6_sum).clamp(min=0).sqrt()) / (x6_sum + eps) + + # target_ratio is the ratio between the variance we want, to the variance we got + # with this alpha value. it + target_ratio = (beta1_2 * x2_sum) / (x2_sum - 2 * alpha * x4_sum + alpha**2 * x6_sum) + + post_scale = target_ratio ** 0.5 # post-scaling on x, after applying alpha. + + x.add_(x3 * alpha[:, None, None], alpha=-1) + + x *= post_scale[:, None, None] if False: print(f"alpha={alpha}, scale={scale * (1-beta1)}") @@ -212,7 +247,6 @@ def scale_by(x, beta1): dot_prod2 = (x * x3).sum() * alpha print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") - x.add_(x3 * (scale * (1-beta1)).clamp(max=alpha)[:, None, None], alpha=-1) @@ -235,12 +269,12 @@ def momentum_step(group, state, grad): stored_delta.add_(delta) - if step % 3 == 0: - # every third step, just do a normal decay, this is an efficient way of - # doing a kind of interpolation with the fourth-power regularization. - stored_delta.mul_(beta1) - else: - scale_by(stored_delta, beta1) + #if step % 3 == 0: + # # every third step, just do a normal decay, this is an efficient way of + # # doing a kind of interpolation with the fourth-power regularization. + # stored_delta.mul_(beta1) + #else: + scale_by(stored_delta, beta1) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) From 745ed82d5ef18e3ef379a2b95715449bce2d3192 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 18:56:04 +0800 Subject: [PATCH 0721/1191] Remove commented code --- egs/librispeech/ASR/zipformer/optim.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6de0e4f3b1..37185a5be6 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -269,11 +269,6 @@ def momentum_step(group, state, grad): stored_delta.add_(delta) - #if step % 3 == 0: - # # every third step, just do a normal decay, this is an efficient way of - # # doing a kind of interpolation with the fourth-power regularization. - # stored_delta.mul_(beta1) - #else: scale_by(stored_delta, beta1) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) From b01e41cd83846267ade6bc35f2369015d1e8e8cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 17 Nov 2025 23:03:23 +0800 Subject: [PATCH 0722/1191] Swap test to use sched3, is about the same. --- egs/librispeech/ASR/zipformer/optim.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 37185a5be6..ecaa1b5200 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -249,7 +249,6 @@ def scale_by(x, beta1): - def momentum_step(group, state, grad): delta = base_step(group, state, grad) # delta is the normalized gradient; the rms of delta should be around 1. @@ -1559,12 +1558,11 @@ def _test_transformed_adam(hidden_dim: int): else: assert "unknown test", test - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + scheduler = Sched3(optim, lr_batches=120, power=0.9, verbose=False) start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): - scheduler.step_epoch() # if epoch == 100 and test in [2,3]: # optim.reset_speedup() # check it doesn't crash. @@ -1575,6 +1573,7 @@ def _test_transformed_adam(hidden_dim: int): # diagnostic = diagnostics.attach_diagnostics(m, opts) for n, (x, y) in enumerate(train_pairs): + scheduler.step_batch() y_out = m(x) loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: @@ -1598,7 +1597,6 @@ def _test_transformed_adam(hidden_dim: int): loss.log().backward() optim.step() optim.zero_grad() - scheduler.step_batch() # diagnostic.print_diagnostics() From 664e77af8bc600e3489c7dfd82ad6e535f6e7847 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 10:08:25 +0800 Subject: [PATCH 0723/1191] Fix to printing unusual grads. --- egs/librispeech/ASR/zipformer/optim.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index ecaa1b5200..8391d078d9 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -828,9 +828,10 @@ def _show_param_with_unusual_grad( largest_name = "" ratios_names = [ ] for (p, state, batch_param_names) in tuples: - dims = list(range(1, p.ndim)) + def mean(x): + return x.mean(dim=tuple(range(1, x.ndim))) if x.ndim > 1 else x - grad_ratio = ((p.grad ** 2).mean(dim=dims) / state["exp_avg_sq"].mean(dim=dims)).sqrt() + grad_ratio = (mean(p.grad ** 2) / mean(state["exp_avg_sq"])).sqrt() ratios_names += zip(grad_ratio.to('cpu').tolist(), batch_param_names) ratios_names = sorted(ratios_names, reverse=True) From 10fb6f1be64109d247b28bccf136969b71d2c66c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 14:02:37 +0800 Subject: [PATCH 0724/1191] Replace convolution in convolution module with something based on hilbert transform. --- egs/librispeech/ASR/zipformer/zipformer.py | 172 ++++----------------- 1 file changed, 31 insertions(+), 141 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4d3dba83a0..86f3a32874 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -555,7 +555,6 @@ def __init__( def forward( self, src: Tensor, - weight_proj: Tensor, pos_emb: Tensor, chunk_size: int = -1, attn_mask: Optional[Tensor] = None, @@ -566,7 +565,6 @@ def forward( Pass the input through the encoder layer. Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - weight_proj: to be passed to the convolution modules, of shape (max_conv_length, conv_params) pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), @@ -597,7 +595,7 @@ def forward( src = src + self.self_attn(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module(src, weight_proj, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -773,13 +771,6 @@ def __init__( dim=0)) - conv_params = encoder_layer.conv_module.depthwise_conv.weight.shape[1] - max_conv_length = 255 - self.weight_proj = nn.Parameter(torch.randn(max_conv_length, conv_params)) - # scale weight_proj with a scale that's smaller for 'further-away-from-the-center' positions, since these positions - # will tend to have smaller weights. - self.register_buffer('weight_proj_scale', (1. / (2. + (torch.arange(max_conv_length) - (max_conv_length // 2)).abs())).unsqueeze(-1)) - self.copy_bypass = Identity() @@ -822,12 +813,10 @@ def forward( min=-1.0, max=-0.5) src_with_bypass = residual_scale * src - weight_proj = self.weight_proj * self.weight_proj_scale for i, mod in enumerate(self.layers): src = mod( src, - weight_proj, pos_emb, chunk_size=chunk_size, attn_mask=attn_mask, @@ -1630,138 +1619,23 @@ def __init__(self, -class ProjDepthwiseConv(nn.Module): - def __init__(self, - num_channels: int, - params_per_channel: int, - bias: bool = True): +class Hilbert(nn.Module): + def __init__(self): super().__init__() - # initialize to identity function. - self.weight = nn.Parameter((params_per_channel ** -0.5) * torch.randn(num_channels, params_per_channel)) - if bias: - self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) - else: - self.bias = None def forward(self, - x: Tensor, - weight_proj: Tensor) -> Tensor: - return self.forward_fft(x, weight_proj) - #a = self.forward_fft(x, weight_proj) - #b = self.forward_conv(x, weight_proj) - #diff = a - b - #def rms(x): - # return (x**2).mean().sqrt() - #print(f"size={x.shape}, rms(a)={rms(a)}, rms(b)={rms(b)}, rms(diff)={rms(diff)}, rms(diff-last)={rms(diff[-1])}, rms(diff-first)={rms(diff[0])}") - #return a - - - def forward_conv(self, - x: Tensor, - weight_proj: Tensor) -> Tensor: + x: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape - (_num_channels, params_per_channel) = self.weight.shape - - weight_proj = weight_proj.t() - # weight_proj: (params_per_channel, conv_length) - assert weight_proj.shape[0] == params_per_channel - conv_length = weight_proj.shape[1] - assert conv_length % 2 == 1 - - # if convolution length is longer than seq_len, we can truncate the convolution by - # wrapping it around (so it will be the same as if we did the full convolution - # with circular padding) - - if conv_length > seq_len: - wrapped_conv_length = seq_len - - # 'multiple' is the number of 'wraps' we sum over. this must be odd so - # that the original middle ends up in the middle after wrapping. - multiple = (conv_length + seq_len - 1) // seq_len - if multiple % 2 == 0: - multiple = multiple + 1 # need multiple to be odd. - padding = (seq_len * multiple) - conv_length - left_pad = padding // 2 - right_pad = padding - left_pad - weight_proj = torch.nn.functional.pad(weight_proj, (left_pad, right_pad)) - - weight_proj = weight_proj.reshape(params_per_channel, multiple, seq_len).sum(dim=1) - # weight_proj: (num_channels, seq_len) - if seq_len % 2 == 0: - # even-length convolution will cause efficiency problems for conv1d, so we pad - # the convolution with a zero on the left (which would have been the side that - # was made shorter by the uneven padding). The fact that it's zero won't matter - # because we'll just get the value from the wrapped around other side, due to - # circular padding. - weight_proj = torch.cat((torch.zeros(params_per_channel, 1, device=weight_proj.device, dtype=weight_proj.dtype), - weight_proj), dim=1) - conv_length = weight_proj.shape[1] - - - weight = torch.matmul(self.weight, weight_proj) - # weight: (num_channels, conv_length) ; note, conv_length may have been reduced to seq_len + 1 already. - padding = conv_length // 2 # note, conv_length will be odd. - - # weight: (num_channels, conv_width); conv_width is odd. - - x = x.permute(1, 2, 0) # (batch, channels, width) - weight = weight.unsqueeze(1) # (num_channels, 1, conv_width) - - x = torch.nn.functional.pad(x, (padding, padding), mode='circular') - x = torch.nn.functional.conv1d(x, weight, self.bias, groups=num_channels) - x = x.permute(2, 0, 1) # (seq, batch, channels) - return x - - - def forward_fft(self, - x: Tensor, - weight_proj: Tensor) -> Tensor: - (seq_len, batch_size, num_channels) = x.shape - (_num_channels, params_per_channel) = self.weight.shape - - weight_proj = weight_proj.t() - # weight_proj: (params_per_channel, conv_length) - assert weight_proj.shape[0] == params_per_channel - conv_length = weight_proj.shape[1] - assert conv_length % 2 == 1 - - # if convolution length is longer than seq_len, we can truncate the convolution by - # wrapping it around (so it will be the same as if we did the full convolution - # with circular padding) - - - middle = conv_length // 2 - # pad the convolution so that its middle point is positioned at an exact multiple of seq_len, - # which will become position zero after circular summing; and so that the total length is an - # exact multiple of seq_len. - left_pad = (-middle) % seq_len # caution if you translate this into C, this relies on python's definition. - right_pad = (-(conv_length + left_pad)) % seq_len - weight_proj = torch.nn.functional.pad(weight_proj, (left_pad, right_pad)) - - weight_proj = weight_proj.reshape(params_per_channel, -1, seq_len).sum(dim=1) - # weight_proj: (num_channels, seq_len). Central point of conv is positioned - # at position zero. - - weight = torch.matmul(self.weight, weight_proj) - # weight: (num_channels, seq_len). - - - x = x.permute(1, 2, 0) # (batch_size, num_channels, seq_len) - - both = torch.cat((x, weight.unsqueeze(0)), dim=0) - # both: (batch_size + 1, num_channels, seq_len) with torch.amp.autocast('cuda', enabled=False): # do it in float32 because non power of two seq_len is not supported in half precision. - both = torch.fft.rfft(both.to(torch.float32)) + x = torch.fft.rfft(x.to(torch.float32), dim=0) - # multiplication in fourier space is the same as (circular) convolution. - x = both[:-1] * both[-1].conj() + x = x * 1j - x = torch.fft.irfft(x, n=seq_len) + x = torch.fft.irfft(x, n=seq_len, dim=0) - x = x.permute(2, 0, 1) + self.bias # (seq, batch, channels) return x @@ -1802,21 +1676,24 @@ def __init__( self.activation2 = Identity() # for diagnostics - self.depthwise_conv = ProjDepthwiseConv(bottleneck_dim, - kernel_size) + self.hilbert = Hilbert() # hilbert transform - self.out_proj = ActivationDropoutAndLinear( + # phase change + self.phase_shift = nn.Parameter(torch.randn(bottleneck_dim)) + + # bias that determines gain. + self.bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) + + + self.out_proj = ScaledLinear( bottleneck_dim, channels, - activation="SwashR", - dropout_p=0.0, initial_scale=0.05, ) def forward( self, x: Tensor, - weight_proj: Tensor, src_key_padding_mask: Optional[Tensor] = None, chunk_size: int = -1, aux_loss_scale: float = 0.0, @@ -1825,7 +1702,6 @@ def forward( Args: x: Input tensor (#time, batch, channels). - weight_proj: tensor of shape (max_conv_length, kernel_size), with max_conv_length > kernel_size; expands the size of the convolution. src_key_padding_mask: the mask for the src keys per batch (optional): (batch, #time), contains True in masked positions. @@ -1845,7 +1721,21 @@ def forward( if src_key_padding_mask is not None: x = self.repeat_in_padding(x, src_key_padding_mask) - x = self.depthwise_conv(x, weight_proj) + + x = torch.complex(x, self.hilbert(x).to(x.dtype)) + + + x_abs = x.abs() + + # change the phase of x by multiplying by a phase shift. + phase_shift = self.phase_shift + x = x * torch.polar(torch.ones_like(phase_shift), phase_shift) + + eps = 1.0e-05 + x_scale = (x_abs + self.bias) / (x_abs + eps) + x_scale = x_scale.clamp(max=4.0) # this is to limit gain which should lead to gradient blowup. + + x = x_scale * torch.real(x) x = self.out_proj(x) # (time, batch, channels) From efd57e73f903640c3966688f8c82fcbdf5f58918 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 14:44:29 +0800 Subject: [PATCH 0725/1191] Have the frequency shift be frequency specific. --- egs/librispeech/ASR/zipformer/zipformer.py | 54 +++++++++------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 86f3a32874..89b7dd600e 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1604,37 +1604,39 @@ def round_up_to_power_of_two(x): return x -class ProjDepthwiseConv(nn.Module): - def __init__(self, - num_channels: int, - params_per_channel: int, - bias: bool = True): - super().__init__() - # initialize to identity function. - self.weight = nn.Parameter((params_per_channel ** -0.5) * torch.randn(num_channels, params_per_channel)) - if bias: - self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) - else: - self.bias = None - -class Hilbert(nn.Module): - def __init__(self): +class PhaseShift(nn.Module): + def __init__(self, + num_channels: int, + params_per_channel: int): super().__init__() - + self.phase_shift = nn.Parameter(torch.randn(num_channels, params_per_channel)) def forward(self, x: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape + + phase_shift = self.phase_shift.unsqueeze(0) # (1, num_channels, params_per_channel) + with torch.amp.autocast('cuda', enabled=False): # do it in float32 because non power of two seq_len is not supported in half precision. x = torch.fft.rfft(x.to(torch.float32), dim=0) - x = x * 1j + # x: (seq_len, batch_size, num_channels) - x = torch.fft.irfft(x, n=seq_len, dim=0) + N = x.shape[0] # num freqs + phase_shift = torch.nn.functional.interpolate(phase_shift, N, mode='linear', align_corners=True) + # phase_shift: (1, num_channels, num_freq) + phase_shift = phase_shift.permute(2, 0, 1) # (num_freq, 1, num_channels) + + x = x * torch.polar(torch.ones_like(phase_shift), phase_shift) + + x_real = torch.fft.irfft(x, n=seq_len, dim=0) + x_im = torch.fft.irfft(x * 1j, n=seq_len, dim=0) + + x = torch.complex(x_real, x_im) return x @@ -1675,11 +1677,7 @@ def __init__( self.activation2 = Identity() # for diagnostics - - self.hilbert = Hilbert() # hilbert transform - - # phase change - self.phase_shift = nn.Parameter(torch.randn(bottleneck_dim)) + self.phase_shift = PhaseShift(bottleneck_dim, kernel_size) # computes analytic signal with frequency specific phase shift # bias that determines gain. self.bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) @@ -1721,16 +1719,8 @@ def forward( if src_key_padding_mask is not None: x = self.repeat_in_padding(x, src_key_padding_mask) - - x = torch.complex(x, self.hilbert(x).to(x.dtype)) - - + x = self.phase_shift(x) x_abs = x.abs() - - # change the phase of x by multiplying by a phase shift. - phase_shift = self.phase_shift - x = x * torch.polar(torch.ones_like(phase_shift), phase_shift) - eps = 1.0e-05 x_scale = (x_abs + self.bias) / (x_abs + eps) x_scale = x_scale.clamp(max=4.0) # this is to limit gain which should lead to gradient blowup. From 20b56fb975771cbc12d52b544858edf2436afb5c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 14:59:15 +0800 Subject: [PATCH 0726/1191] Reduce layers from 6,9,26,9 to 6,8,22,8. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 946487c569..aff6eff748 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,9,26,9", + default="6,8,22,8", help="Number of zipformer encoder layers per stack, comma separated.", ) From bb5fe93e6b5a33741a70872a23336992a4f62719 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 15:15:58 +0800 Subject: [PATCH 0727/1191] Decrease scale_max from 4.0 to 2.0. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 89b7dd600e..8aa6490758 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1723,7 +1723,10 @@ def forward( x_abs = x.abs() eps = 1.0e-05 x_scale = (x_abs + self.bias) / (x_abs + eps) - x_scale = x_scale.clamp(max=4.0) # this is to limit gain which should lead to gradient blowup. + scale_max = 2.0 + # make this strictly more than 1.0, but not too large, as it is the maximum amount by which this + # operation can blow up the gradient. + x_scale = x_scale.clamp(max=scale_max) x = x_scale * torch.real(x) From bb4db3977065282ea13d643fde0c890a015687db Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 15:34:54 +0800 Subject: [PATCH 0728/1191] Add SwashR nonlinearity at output of conv module. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8aa6490758..3519ececad 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1682,10 +1682,11 @@ def __init__( # bias that determines gain. self.bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) - - self.out_proj = ScaledLinear( + self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, channels, + activation="SwashR", + dropout_p=0.0, initial_scale=0.05, ) From 89e849380b051afb9fbc6b962a7d80bf17673e6d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 17:15:07 +0800 Subject: [PATCH 0729/1191] Have weights, not just phases. Do sqrt to compress magnitudes. --- egs/librispeech/ASR/zipformer/zipformer.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3519ececad..57d0e20109 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1611,15 +1611,16 @@ def __init__(self, num_channels: int, params_per_channel: int): super().__init__() - self.phase_shift = nn.Parameter(torch.randn(num_channels, params_per_channel)) + self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) + # the factor of 2 is for (sin, cos) + self.weight_proj = nn.Linear(params_per_channel, 2 * params_per_channel) + def forward(self, x: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape - phase_shift = self.phase_shift.unsqueeze(0) # (1, num_channels, params_per_channel) - with torch.amp.autocast('cuda', enabled=False): # do it in float32 because non power of two seq_len is not supported in half precision. x = torch.fft.rfft(x.to(torch.float32), dim=0) @@ -1627,11 +1628,18 @@ def forward(self, # x: (seq_len, batch_size, num_channels) N = x.shape[0] # num freqs - phase_shift = torch.nn.functional.interpolate(phase_shift, N, mode='linear', align_corners=True) - # phase_shift: (1, num_channels, num_freq) - phase_shift = phase_shift.permute(2, 0, 1) # (num_freq, 1, num_channels) - x = x * torch.polar(torch.ones_like(phase_shift), phase_shift) + + weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) + weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) + weight = torch.view_as_complex(weight.permute(2, 0, 1).contiguous()) + # weight: (N, num_channels) + weight = weight.unsqueeze(1) # (N, 1, num_channels) + # the following should be tested. it's to make the magnitudes of the weights closer to 1. + eps = 1.0e-05 + weight = weight / (weight.abs() + eps).sqrt() + + x = x * weight x_real = torch.fft.irfft(x, n=seq_len, dim=0) x_im = torch.fft.irfft(x * 1j, n=seq_len, dim=0) From aae1b0e63a6b2d775896196e8fe3eb3b73cdb955 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 17:29:32 +0800 Subject: [PATCH 0730/1191] Every 4 batches, scale the momentum stats in the normal way. --- egs/librispeech/ASR/zipformer/optim.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8391d078d9..862e2fe4fd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -268,7 +268,10 @@ def momentum_step(group, state, grad): stored_delta.add_(delta) - scale_by(stored_delta, beta1) + if step % 4 == 0: + stored_delta.mul_(beta1) + else: + scale_by(stored_delta, beta1) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) From 1e0d700ef52321d00346dec37678fa6c6038fb26 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 17:29:32 +0800 Subject: [PATCH 0731/1191] Every 4 batches, scale the momentum stats in the normal way. --- egs/librispeech/ASR/zipformer/optim.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8391d078d9..862e2fe4fd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -268,7 +268,10 @@ def momentum_step(group, state, grad): stored_delta.add_(delta) - scale_by(stored_delta, beta1) + if step % 4 == 0: + stored_delta.mul_(beta1) + else: + scale_by(stored_delta, beta1) return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) From e663215ecf3a9a124699342f866422adb45bade2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 20:19:07 +0800 Subject: [PATCH 0732/1191] Remove SwashR nonlinearity and have paramerterized nonlinearity on the complex output. --- egs/librispeech/ASR/zipformer/zipformer.py | 26 +++++++++------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 57d0e20109..de4f64a5af 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1578,13 +1578,11 @@ def __init__(self, embed_dim: int, feedforward_dim: int): self.in_proj.weight_min_rms = 0.02 # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( + self.out_proj = ScaledLinear( feedforward_dim, embed_dim, - activation="SwashL", - dropout_p=0.0, - bias=True, initial_scale=0.5, + bias=True, ) @@ -1687,8 +1685,10 @@ def __init__( self.phase_shift = PhaseShift(bottleneck_dim, kernel_size) # computes analytic signal with frequency specific phase shift - # bias that determines gain. - self.bias = nn.Parameter(0.01 * torch.randn(bottleneck_dim)) + # have a small num centers due to concerns over memory. + num_centers = 2 + self.centers = nn.Parameter(torch.randn(num_centers, bottleneck_dim, 2)) # real, im. + self.center_weights = nn.Parameter(0.1 * torch.randn(num_centers, bottleneck_dim)) self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, @@ -1728,16 +1728,12 @@ def forward( if src_key_padding_mask is not None: x = self.repeat_in_padding(x, src_key_padding_mask) - x = self.phase_shift(x) - x_abs = x.abs() - eps = 1.0e-05 - x_scale = (x_abs + self.bias) / (x_abs + eps) - scale_max = 2.0 - # make this strictly more than 1.0, but not too large, as it is the maximum amount by which this - # operation can blow up the gradient. - x_scale = x_scale.clamp(max=scale_max) + x = self.phase_shift(x) # x (complex): (time, batch, bottleneck_dim) + + centers = torch.view_as_complex(self.centers) # (num_centers, bottleneck_dim) + center_weights = self.center_weights # (num_centers, bottleneck_dim) - x = x_scale * torch.real(x) + x = x.real + (center_weights * (x.unsqueeze(2) - centers).abs()).sum(dim=2) x = self.out_proj(x) # (time, batch, channels) From 601ff49664be8d50c14d1423c0ea38248cc87cc7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 20:38:15 +0800 Subject: [PATCH 0733/1191] Have more centers and compute center nonlin with checkpointing. --- egs/librispeech/ASR/zipformer/zipformer.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index de4f64a5af..6d799a3b28 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1647,6 +1647,13 @@ def forward(self, return x +def compute_complex_nonlin(x: Tensor, centers: Tensor, center_weights: Tensor): + # x: complex, (time, batch, bottleneck_dim) + # centers: complex, (num_centers, bottleneck_dim) + # centers: comlex, (num_centers, bottleneck_dim) + return x.real + (center_weights * (x.unsqueeze(2) - centers).abs()).sum(dim=2) + + class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer2 model. Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py @@ -1686,7 +1693,7 @@ def __init__( self.phase_shift = PhaseShift(bottleneck_dim, kernel_size) # computes analytic signal with frequency specific phase shift # have a small num centers due to concerns over memory. - num_centers = 2 + num_centers = 6 self.centers = nn.Parameter(torch.randn(num_centers, bottleneck_dim, 2)) # real, im. self.center_weights = nn.Parameter(0.1 * torch.randn(num_centers, bottleneck_dim)) @@ -1730,10 +1737,8 @@ def forward( x = self.phase_shift(x) # x (complex): (time, batch, bottleneck_dim) - centers = torch.view_as_complex(self.centers) # (num_centers, bottleneck_dim) - center_weights = self.center_weights # (num_centers, bottleneck_dim) - - x = x.real + (center_weights * (x.unsqueeze(2) - centers).abs()).sum(dim=2) + x = torch.utils.checkpoint.checkpoint(compute_complex_nonlin, x, centers, center_weights, + use_reentrant=False) x = self.out_proj(x) # (time, batch, channels) From 60bd232e1524e4c6bb5944bb40b6fd262d92ac58 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 20:45:25 +0800 Subject: [PATCH 0734/1191] More memory efficient computation and have more centers. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6d799a3b28..9a65e566ba 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1737,7 +1737,10 @@ def forward( x = self.phase_shift(x) # x (complex): (time, batch, bottleneck_dim) - x = torch.utils.checkpoint.checkpoint(compute_complex_nonlin, x, centers, center_weights, + centers = torch.view_as_complex(self.centers) # (num_centers, bottleneck_dim) + center_weights = self.center_weights # (num_centers, bottleneck_dim) + x = torch.utils.checkpoint.checkpoint(compute_complex_nonlin, x, + centers, center_weights, use_reentrant=False) x = self.out_proj(x) # (time, batch, channels) From 0a657bbcd2f57f5371846230aa42ea1b6018ddf2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 21:59:47 +0800 Subject: [PATCH 0735/1191] Restore SwashR nonlinearity --- egs/librispeech/ASR/zipformer/zipformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9a65e566ba..9d91c93778 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1578,9 +1578,11 @@ def __init__(self, embed_dim: int, feedforward_dim: int): self.in_proj.weight_min_rms = 0.02 # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ScaledLinear( + self.out_proj = ActivationDropoutAndLinear( feedforward_dim, embed_dim, + dropout_p=0.0, + activation="SwashR", initial_scale=0.5, bias=True, ) From f437d9b9686312b74cc31328e123c94084243d1f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 18 Nov 2025 23:40:01 +0800 Subject: [PATCH 0736/1191] Do something different with centers, shrink towards a point, then take the real part. --- egs/librispeech/ASR/zipformer/zipformer.py | 27 +++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9d91c93778..dff2ff098d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1649,11 +1649,24 @@ def forward(self, return x -def compute_complex_nonlin(x: Tensor, centers: Tensor, center_weights: Tensor): +def compute_complex_nonlin(x: Tensor, centers: Tensor, biases: Tensor): # x: complex, (time, batch, bottleneck_dim) # centers: complex, (num_centers, bottleneck_dim) - # centers: comlex, (num_centers, bottleneck_dim) - return x.real + (center_weights * (x.unsqueeze(2) - centers).abs()).sum(dim=2) + # biases: comlex, (num_centers, bottleneck_dim) + num_centers = centers.shape[0] + biases = - biases.abs() # make all the biases negative + for i in range(num_centers): + c = centers[i] + b = biases[i] + x = x - c + eps = 1.0e-05 + x_abs = x.abs() + # shrink towards this central point. + scale = (x_abs + b).relu() / (x_abs + eps) + x = x * scale + x = x + c + return x.real + class ConvolutionModule(nn.Module): @@ -1694,10 +1707,9 @@ def __init__( self.phase_shift = PhaseShift(bottleneck_dim, kernel_size) # computes analytic signal with frequency specific phase shift - # have a small num centers due to concerns over memory. - num_centers = 6 + num_centers = 2 self.centers = nn.Parameter(torch.randn(num_centers, bottleneck_dim, 2)) # real, im. - self.center_weights = nn.Parameter(0.1 * torch.randn(num_centers, bottleneck_dim)) + self.biases = nn.Parameter(0.1 * torch.randn(num_centers, bottleneck_dim)) self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, @@ -1740,9 +1752,8 @@ def forward( x = self.phase_shift(x) # x (complex): (time, batch, bottleneck_dim) centers = torch.view_as_complex(self.centers) # (num_centers, bottleneck_dim) - center_weights = self.center_weights # (num_centers, bottleneck_dim) x = torch.utils.checkpoint.checkpoint(compute_complex_nonlin, x, - centers, center_weights, + centers, self.biases, use_reentrant=False) x = self.out_proj(x) # (time, batch, channels) From 7c353e8dc3feef4ac679be2f5b9cd14152cc3cde Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Nov 2025 12:06:25 +0800 Subject: [PATCH 0737/1191] Reduce num-layers from 6,8,22,8 to 5,7,20,9 and increase value-head-dim from 48 to 64. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 666f767fa1..f80e38e35d 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,22,8", + default="6,7,20,9", help="Number of zipformer encoder layers per stack, comma separated.", ) @@ -241,7 +241,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--value-head-dim", type=str, - default="48", + default="64", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) From e74fa5ff47d7585a264557b6ef60aa771d557700 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Nov 2025 12:07:08 +0800 Subject: [PATCH 0738/1191] Reduce num-layers from 6,8,22,8 to 5,7,20,9 and increase value-head-dim from 48 to 64. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index f80e38e35d..fe90b8d166 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,7,20,9", + default="5,7,20,9", help="Number of zipformer encoder layers per stack, comma separated.", ) From 171b801fc58d96370c51daf1f6146f898aac9700 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 19 Nov 2025 12:57:26 +0800 Subject: [PATCH 0739/1191] Simplify convolution module to use more basic fft-based convolution. --- egs/librispeech/ASR/zipformer/zipformer.py | 70 +++++++--------------- 1 file changed, 22 insertions(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index dff2ff098d..7e2a311621 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1606,15 +1606,19 @@ def round_up_to_power_of_two(x): -class PhaseShift(nn.Module): +class FftConv(nn.Module): def __init__(self, num_channels: int, - params_per_channel: int): + params_per_channel: int, + bias: bool = True): super().__init__() self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) # the factor of 2 is for (sin, cos) self.weight_proj = nn.Linear(params_per_channel, 2 * params_per_channel) + if bias: + self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) + def forward(self, x: Tensor) -> Tensor: @@ -1624,49 +1628,24 @@ def forward(self, with torch.amp.autocast('cuda', enabled=False): # do it in float32 because non power of two seq_len is not supported in half precision. x = torch.fft.rfft(x.to(torch.float32), dim=0) - - # x: (seq_len, batch_size, num_channels) - + # x: (num_freqs, batch_size, num_channels) N = x.shape[0] # num freqs - - weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) weight = torch.view_as_complex(weight.permute(2, 0, 1).contiguous()) # weight: (N, num_channels) weight = weight.unsqueeze(1) # (N, 1, num_channels) - # the following should be tested. it's to make the magnitudes of the weights closer to 1. - eps = 1.0e-05 - weight = weight / (weight.abs() + eps).sqrt() - x = x * weight + x = torch.fft.irfft(x, n=seq_len, dim=0) - x_real = torch.fft.irfft(x, n=seq_len, dim=0) - x_im = torch.fft.irfft(x * 1j, n=seq_len, dim=0) - - x = torch.complex(x_real, x_im) + try: + x = x + self.bias + except AttributeError: + pass return x -def compute_complex_nonlin(x: Tensor, centers: Tensor, biases: Tensor): - # x: complex, (time, batch, bottleneck_dim) - # centers: complex, (num_centers, bottleneck_dim) - # biases: comlex, (num_centers, bottleneck_dim) - num_centers = centers.shape[0] - biases = - biases.abs() # make all the biases negative - for i in range(num_centers): - c = centers[i] - b = biases[i] - x = x - c - eps = 1.0e-05 - x_abs = x.abs() - # shrink towards this central point. - scale = (x_abs + b).relu() / (x_abs + eps) - x = x * scale - x = x + c - return x.real - class ConvolutionModule(nn.Module): @@ -1705,11 +1684,7 @@ def __init__( self.activation2 = Identity() # for diagnostics - self.phase_shift = PhaseShift(bottleneck_dim, kernel_size) # computes analytic signal with frequency specific phase shift - - num_centers = 2 - self.centers = nn.Parameter(torch.randn(num_centers, bottleneck_dim, 2)) # real, im. - self.biases = nn.Parameter(0.1 * torch.randn(num_centers, bottleneck_dim)) + self.depthwise_conv = FftConv(bottleneck_dim, kernel_size) self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, @@ -1737,6 +1712,14 @@ def forward( Tensor: Output tensor (#time, batch, channels). """ + + # x: (time, batch, channels) + # Caution: this module is not completely + # invariant to the number of frames each sequence is padded with, since + # the FFT-based convolution treats the signal as repeating. + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) + x = self.in_proj(x) # (time, batch, 2*channels) x, s = x.chunk(2, dim=2) @@ -1745,16 +1728,7 @@ def forward( x = x * s x = self.activation2(x) # identity - #x: (time, batch, channels) - if src_key_padding_mask is not None: - x = self.repeat_in_padding(x, src_key_padding_mask) - - x = self.phase_shift(x) # x (complex): (time, batch, bottleneck_dim) - - centers = torch.view_as_complex(self.centers) # (num_centers, bottleneck_dim) - x = torch.utils.checkpoint.checkpoint(compute_complex_nonlin, x, - centers, self.biases, - use_reentrant=False) + x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) x = self.out_proj(x) # (time, batch, channels) From 6bff31fdd15d507ed17f1525d8c67e0eb3d7106b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Nov 2025 10:54:42 +0800 Subject: [PATCH 0740/1191] Change lr-batches from 17.5k to 15k. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index fe90b8d166..9b02cb6986 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -454,13 +454,13 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( "--lr-batches", type=float, - default=17500, + default=15000, help="""Number of steps that affects how rapidly the learning rate decreases. We suggest not to change this.""", ) From 44578a1ee19add463a69d22c29258ee2b80ec924 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 20 Nov 2025 16:54:08 +0800 Subject: [PATCH 0741/1191] Decrease beta1 from 0.9995 to 0.995. --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 862e2fe4fd..d4e26e82ac 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -466,7 +466,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.9995, + beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, @@ -948,7 +948,7 @@ def __init__( params, lr=3e-02, clipping_scale=None, - beta1=0.999, + beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, scale_decay=0.01, From d3e2d779fbfc4ea9f7415185bf6794df5c00d1b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Nov 2025 11:51:45 +0800 Subject: [PATCH 0742/1191] Introduce additive bypass around the convolution module, as a kind of gating. Reduce num layers a bit to compensate increased params. --- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index fe90b8d166..3ab071a9c6 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="5,7,20,9", + default="5,7,18,7", help="Number of zipformer encoder layers per stack, comma separated.", ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7e2a311621..d87f2bb397 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1672,7 +1672,7 @@ def __init__( self.in_proj = nn.Linear( channels, - 2 * bottleneck_dim, + 3 * bottleneck_dim, ) # the gradients on in_proj are a little noisy, likely to do with the # sigmoid in glu. @@ -1722,7 +1722,7 @@ def forward( x = self.in_proj(x) # (time, batch, 2*channels) - x, s = x.chunk(2, dim=2) + x, s, y = x.chunk(3, dim=2) s = self.sigmoid(s) x = self.activation1(x) # identity. x = x * s @@ -1730,7 +1730,7 @@ def forward( x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) - x = self.out_proj(x) # (time, batch, channels) + x = self.out_proj(x + y) # (time, batch, channels) return x From 0ae434d2b6e6d5ee723b2ac3b3ff989c74ef4ad3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 21 Nov 2025 17:20:01 +0800 Subject: [PATCH 0743/1191] Change weighting in depthwise conv to be multiplication by sigmoid. --- egs/librispeech/ASR/zipformer/zipformer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d87f2bb397..3115918969 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1680,7 +1680,9 @@ def __init__( self.activation1 = Identity() # for diagnostics - self.sigmoid = nn.Sigmoid() + self.sigmoid1 = nn.Sigmoid() + + self.sigmoid2 = nn.Sigmoid() self.activation2 = Identity() # for diagnostics @@ -1723,14 +1725,16 @@ def forward( x = self.in_proj(x) # (time, batch, 2*channels) x, s, y = x.chunk(3, dim=2) - s = self.sigmoid(s) + s = self.sigmoid1(s) + y = self.sigmoid2(y) x = self.activation1(x) # identity. x = x * s x = self.activation2(x) # identity x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) - x = self.out_proj(x + y) # (time, batch, channels) + x = x * y + x = self.out_proj(x) # (time, batch, channels) return x From e21aa859ece367b7f732f053e0215dbe9be22284 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Nov 2025 13:59:31 +0800 Subject: [PATCH 0744/1191] Reduce num layers from 5,7,20,9 to 5,7,20,7; increase num-heads from 3 to 4. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index fe90b8d166..4a034a140a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -185,7 +185,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="5,7,20,9", + default="5,7,20,7", help="Number of zipformer encoder layers per stack, comma separated.", ) @@ -220,7 +220,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-heads", type=str, - default="3", + default="4", help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", ) From 9d7fac799a7115509cba98292284392dff4fa299 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Nov 2025 14:26:54 +0800 Subject: [PATCH 0745/1191] Increase query-head-dim from 32 to 64. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 4a034a140a..1f85b63e12 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -234,7 +234,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--query-head-dim", type=str, - default="32", + default="64", help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) From f08851c5d8ca14d9e643f1a7b9ffed5f953395ee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Nov 2025 20:29:53 +0800 Subject: [PATCH 0746/1191] Merge branch 'deterministic_invertible1637conv' into deterministic_invertible1649conv From 70bd1ea59ebc125087b6015dbf9fb1c4c5673173 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Nov 2025 20:30:23 +0800 Subject: [PATCH 0747/1191] Revert unwanted change. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index bb74e06785..f81c4940d9 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -460,7 +460,7 @@ def get_parser(): parser.add_argument( "--lr-batches", type=float, - default=15000, + default=17500, help="""Number of steps that affects how rapidly the learning rate decreases. We suggest not to change this.""", ) From d1ba202315e5c8f8ccf92b7afd7ba0b640d91c3d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Nov 2025 12:14:37 +0800 Subject: [PATCH 0748/1191] Decrase query-head-dim and value-head-dim from 64 to 32. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index abeb950501..3b11590f57 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -234,14 +234,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--query-head-dim", type=str, - default="64", + default="32", help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) parser.add_argument( "--value-head-dim", type=str, - default="64", + default="32", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) From 82438b4f1cef7218e0744e11c6928cf44f4c6bc8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 24 Nov 2025 16:43:00 +0800 Subject: [PATCH 0749/1191] Save some temporary work --- .../ASR/zapformer/asr_datamodule.py | 145 ++++++++--- .../ASR/zapformer/multicopy_dataset.py | 243 ++++++++++++++++++ 2 files changed, 349 insertions(+), 39 deletions(-) create mode 100755 egs/librispeech/ASR/zapformer/multicopy_dataset.py diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 4db6e101fb..edf7c35ae5 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -32,13 +32,17 @@ PrecomputedFeatures, SimpleCutSampler, ) -# This K2SpeechRecognitionDataset is a modified version of one from +# MulticopyDataset is a modified version of one from # lhotse.dataset, modified to, in training mode, to return a batch that has 3 # different copies of the same data with the last two having different Musan # augmentations and the first having none; and also include the key "num_copies" # in the batch which would be 1 for the validation data (no Musan) and 3 for the # training data with musan. -from speech_recognition import K2SpeechRecognitionDataset +try: + from multicopy_dataset import MulticopyDataset # interface like K2SpeechRecognitionDataset +except: + pass + from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, OnTheFlyFeatures, @@ -58,6 +62,9 @@ def __call__(self, worker_id: int): class LibriSpeechAsrDataModule: + pass # only left here so other branches can run in the same directory. TODO: remove. + +class AsrDataModule: """ DataModule for k2 ASR experiments. It assumes there is always one train and valid dataloader, @@ -241,7 +248,7 @@ def train_dataloaders( ] + transforms logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( + train = MulticopyDataset( input_strategy=eval(self.args.input_strategy)(), cut_transforms=transforms, input_transforms=[], @@ -259,7 +266,7 @@ def train_dataloaders( # to be strict (e.g. could be randomized) # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( + train = MulticopyDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, @@ -317,13 +324,13 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( + validate = MulticopyDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: - validate = K2SpeechRecognitionDataset( + validate = MulticopyDataset( cut_transforms=transforms, return_cuts=self.args.return_cuts, ) @@ -345,7 +352,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( + test = MulticopyDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else eval(self.args.input_strategy)(), @@ -365,90 +372,150 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: ) return test_dl - @lru_cache() + +class LibriSpeech: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files:: + + - librispeech_cuts_dev-clean.jsonl.gz + - librispeech_cuts_dev-other.jsonl.gz + - librispeech_cuts_test-clean.jsonl.gz + - librispeech_cuts_test-other.jsonl.gz + - librispeech_cuts_train-clean-100.jsonl.gz + - librispeech_cuts_train-clean-360.jsonl.gz + - librispeech_cuts_train-other-500.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + def train_clean_5_cuts(self) -> CutSet: logging.info("mini_librispeech: About to get train-clean-5 cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + self.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" ) - @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" ) - @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" ) - @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" ) - @lru_cache() def train_all_shuf_cuts(self) -> CutSet: logging.info( "About to get the shuffled train-clean-100, \ train-clean-360 and train-other-500 cuts" ) return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" ) - @lru_cache() def dev_clean_2_cuts(self) -> CutSet: logging.info("mini_librispeech: About to get dev-clean-2 cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + self.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" ) - @lru_cache() def dev_clean_cuts(self) -> CutSet: logging.info("About to get dev-clean cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" ) - @lru_cache() def dev_other_cuts(self) -> CutSet: logging.info("About to get dev-other cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" ) - @lru_cache() def test_clean_cuts(self) -> CutSet: logging.info("About to get test-clean cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + self.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" ) - @lru_cache() def test_other_cuts(self) -> CutSet: logging.info("About to get test-other cuts") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + self.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" ) - @lru_cache() - def gigaspeech_subset_small_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech subset-S cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") - @lru_cache() - def gigaspeech_dev_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") +class GigaSpeech: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + + - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz + - gigaspeech_cuts_L_raw.jsonl.gz + - gigaspeech_cuts_M_raw.jsonl.gz + - gigaspeech_cuts_S_raw.jsonl.gz + - gigaspeech_cuts_XS_raw.jsonl.gz + - gigaspeech_cuts_DEV_raw.jsonl.gz + - gigaspeech_cuts_TEST_raw.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_XL_cuts(self) -> CutSet: + logging.info("About to get train-XL cuts") + + filenames = list( + glob.glob( + f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz" # noqa + ) + ) + + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + + sorted_filenames = [f[1] for f in idx_filenames] + + logging.info(f"Loading {len(sorted_filenames)} splits") + + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + + def train_L_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" + logging.info(f"About to get train-L cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_M_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_M_raw.jsonl.gz" + logging.info(f"About to get train-M cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_S_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_S_raw.jsonl.gz" + logging.info(f"About to get train-S cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_XS_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_XS_raw.jsonl.gz" + logging.info(f"About to get train-XS cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def test_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + logging.info(f"About to get TEST cuts from {f}") + return load_manifest_lazy(f) - @lru_cache() - def gigaspeech_test_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") + def dev_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + logging.info(f"About to get DEV cuts from {f}") + return load_manifest_lazy(f) diff --git a/egs/librispeech/ASR/zapformer/multicopy_dataset.py b/egs/librispeech/ASR/zapformer/multicopy_dataset.py new file mode 100755 index 0000000000..f27360ad0b --- /dev/null +++ b/egs/librispeech/ASR/zapformer/multicopy_dataset.py @@ -0,0 +1,243 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class MulticopyDataset(torch.utils.data.Dataset): + """ + This is slightly modified from lhotse's K2SpeechRecognitionDataset, but + modified as suggested by Piotr in this github thread: + https://github.com/k2-fsa/icefall/pull/1975 + + If cut_transforms is specified, which will normally be the case for training + data, where you might specify Musan augmentation, it returns two copies of + the data that differ only in the augmentations, followed by a third unmodified + copy. The structure of the data would be [ a b c d a b c d a b c d ], i.e. + the order is: first copy of all buts, second copy of all cuts, unmodified + copy of all cuts. + If cut_transforms is not specified, this dataset behaves like lhotse's regular + K2SpeechRecognitionDataset. + The yielded dict will have an extra key called "num_copies", set to 3 if + we did the 2 augmentation copies plus one original copy, or 1 if there + were no augmentations. + + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + if self.cut_transforms: + orig_cuts = cuts + + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts + num_copies = 3 + else: + num_copies = 1 + + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) From 0f7a7530767901489a146b2c7729d4f666b34003 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Dec 2025 11:54:33 +0800 Subject: [PATCH 0750/1191] Revert "Decrase query-head-dim and value-head-dim from 64 to 32." This reverts commit d1ba202315e5c8f8ccf92b7afd7ba0b640d91c3d. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 3b11590f57..abeb950501 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -234,14 +234,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--query-head-dim", type=str, - default="32", + default="64", help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) parser.add_argument( "--value-head-dim", type=str, - default="32", + default="64", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) From 5e97b892afc8b058d37976dc6dba6f550b5b6971 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Dec 2025 12:32:48 +0800 Subject: [PATCH 0751/1191] Reorganize data module stuff, support --use-giga --- .../ASR/zapformer/asr_datamodule.py | 23 ++++++++++++-- egs/librispeech/ASR/zapformer/train.py | 31 +++++++++++++++---- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index edf7c35ae5..94e7e9464f 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -32,7 +32,7 @@ PrecomputedFeatures, SimpleCutSampler, ) -# MulticopyDataset is a modified version of one from +# MulticopyDataset is a modified version of K2SpeechRecognitionDataset from # lhotse.dataset, modified to, in training mode, to return a batch that has 3 # different copies of the same data with the last two having different Musan # augmentations and the first having none; and also include the key "num_copies" @@ -98,8 +98,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "--full-libri", type=str2bool, default=True, - help="""Used only when --mini-libri is False.When enabled, - use 960h LibriSpeech. Otherwise, use 100h subset.""", + help="""When enabled, use 960h LibriSpeech; and 10000 hour GigaSpeech if --use-gigs. Otherwise, use 100h and if applicable 250h subsets.""", ) group.add_argument( "--mini-libri", @@ -210,6 +209,24 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="AudioSamples or PrecomputedFeatures", ) + parser.add_argument( + "--libri-copies", + type=int, + default=1, + help="If set to <= 0, we use only librispeech (CAUTION: this may be surprising). If set to > 0, every epoch means one epoch " + "of gigaspeech and libri_copies epochs of librispeech (although it is really libri_copies times 3, because of Librispeech " + "using speed augmentation." + ) + + parser.add_argument( + "--use-giga", + type=str2bool, + default=False, + help="If set to True, use gigaspeech in addition to librispeech. See also --libri-copies." + ) + + + def train_dataloaders( self, cuts_train: CutSet, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index abeb950501..2d9b40ed33 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -66,7 +66,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule, LibriSpeech, GigaSpeech from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner @@ -437,7 +437,7 @@ def get_parser(): ) parser.add_argument( - "--exp-dir", + "--exp-dir", type=str, default="zipformer/exp", help="""The experiment dir. @@ -1425,10 +1425,13 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - librispeech = LibriSpeechAsrDataModule(args) + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(args.manifest_dir) + gigaspeech = GigaSpeech(args.manifest_dir) # gigaspeech will only be used if --libri-copies set. this is not a typo! if params.full_libri: train_cuts = librispeech.train_all_shuf_cuts() + train_cuts_len = 960.0 * 3 # 960 hours times 3 for augmentation # previously we used the following code to load all training cuts, # strictly speaking, shuffled training cuts should be used instead, @@ -1440,6 +1443,22 @@ def run(rank, world_size, args): # train_cuts += librispeech.train_other_500_cuts() else: train_cuts = librispeech.train_clean_100_cuts() + train_cuts_len = 100.0 * 3 # 100 hours times 3 for augmentation + + if params.use_giga: + if params.full_libri: + gigaspeech_cuts = gigaspeech.train_XL_cuts() + gigaspeech_cuts_len = 10000.0 + else: + gigaspeech_cuts = gigaspeech.train_S_cuts() # e.g. for debugging + gigaspeech_cuts_len = 250.0 + if params.libri_copies > 1: + train_cuts = train_cuts.repeat(params.libri_copies) + train_cuts_len = train_cuts_len * params.libri_copies + datsets_and_weights = [ (train_cuts, train_cuts_len), + (gigaspeech_cuts, gigaspeech_cuts_len) ] + cuts, weights = zip(datasets_and_weights) + train_cuts = CutSet.mux(*cuts, weights=weights) def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1487,13 +1506,13 @@ def remove_short_and_long_utt(c: Cut): else: sampler_state_dict = None - train_dl = librispeech.train_dataloaders( + train_dl = asr_datamodule.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) if not params.print_diagnostics and False: scan_pessimistic_batches_for_oom( @@ -1640,7 +1659,7 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) From 5e7d3c8cd506671276410f98dc20265dee901018 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Dec 2025 13:07:21 +0800 Subject: [PATCH 0752/1191] Fixes for gigaspeech --- .../ASR/zapformer/asr_datamodule.py | 27 +++++++++++-------- egs/librispeech/ASR/zapformer/train.py | 6 ++--- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 94e7e9464f..bc77784f80 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -479,16 +479,16 @@ def __init__(self, manifest_dir: str): It is expected to contain the following files: - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz - - gigaspeech_cuts_L_raw.jsonl.gz - - gigaspeech_cuts_M_raw.jsonl.gz - - gigaspeech_cuts_S_raw.jsonl.gz - - gigaspeech_cuts_XS_raw.jsonl.gz - - gigaspeech_cuts_DEV_raw.jsonl.gz - - gigaspeech_cuts_TEST_raw.jsonl.gz + - gigaspeech_cuts_L.jsonl.gz + - gigaspeech_cuts_M.jsonl.gz + - gigaspeech_cuts_S.jsonl.gz + - gigaspeech_cuts_XS.jsonl.gz + - gigaspeech_cuts_DEV.jsonl.gz + - gigaspeech_cuts_TEST.jsonl.gz """ self.manifest_dir = Path(manifest_dir) - def train_XL_cuts(self) -> CutSet: + def train_XL_cuts_split(self) -> CutSet: logging.info("About to get train-XL cuts") filenames = list( @@ -507,23 +507,28 @@ def train_XL_cuts(self) -> CutSet: return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + def train_XL_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_XL.jsonl.gz" + logging.info(f"About to get train-XL cuts from {f}") + return CutSet.from_jsonl_lazy(f) + def train_L_cuts(self) -> CutSet: - f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" + f = self.manifest_dir / "gigaspeech_cuts_L.jsonl.gz" logging.info(f"About to get train-L cuts from {f}") return CutSet.from_jsonl_lazy(f) def train_M_cuts(self) -> CutSet: - f = self.manifest_dir / "gigaspeech_cuts_M_raw.jsonl.gz" + f = self.manifest_dir / "gigaspeech_cuts_M.jsonl.gz" logging.info(f"About to get train-M cuts from {f}") return CutSet.from_jsonl_lazy(f) def train_S_cuts(self) -> CutSet: - f = self.manifest_dir / "gigaspeech_cuts_S_raw.jsonl.gz" + f = self.manifest_dir / "gigaspeech_cuts_S.jsonl.gz" logging.info(f"About to get train-S cuts from {f}") return CutSet.from_jsonl_lazy(f) def train_XS_cuts(self) -> CutSet: - f = self.manifest_dir / "gigaspeech_cuts_XS_raw.jsonl.gz" + f = self.manifest_dir / "gigaspeech_cuts_XS.jsonl.gz" logging.info(f"About to get train-XS cuts from {f}") return CutSet.from_jsonl_lazy(f) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2d9b40ed33..78eb765911 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -70,7 +70,7 @@ from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner -from lhotse.cut import Cut +from lhotse.cut import Cut, CutSet from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -1455,9 +1455,9 @@ def run(rank, world_size, args): if params.libri_copies > 1: train_cuts = train_cuts.repeat(params.libri_copies) train_cuts_len = train_cuts_len * params.libri_copies - datsets_and_weights = [ (train_cuts, train_cuts_len), + datasets_and_weights = [ (train_cuts, train_cuts_len), (gigaspeech_cuts, gigaspeech_cuts_len) ] - cuts, weights = zip(datasets_and_weights) + cuts, weights = zip(*datasets_and_weights) train_cuts = CutSet.mux(*cuts, weights=weights) def remove_short_and_long_utt(c: Cut): From 71f6808fc89c4982837ed90f1d4ae0815abc9b26 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Dec 2025 16:43:07 +0800 Subject: [PATCH 0753/1191] Fix bug with indentation that meant it only used giga if not full_libri. --- egs/librispeech/ASR/zapformer/train.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 78eb765911..66f886fc85 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1452,13 +1452,14 @@ def run(rank, world_size, args): else: gigaspeech_cuts = gigaspeech.train_S_cuts() # e.g. for debugging gigaspeech_cuts_len = 250.0 - if params.libri_copies > 1: - train_cuts = train_cuts.repeat(params.libri_copies) - train_cuts_len = train_cuts_len * params.libri_copies - datasets_and_weights = [ (train_cuts, train_cuts_len), - (gigaspeech_cuts, gigaspeech_cuts_len) ] - cuts, weights = zip(*datasets_and_weights) - train_cuts = CutSet.mux(*cuts, weights=weights) + + if params.libri_copies > 1: + train_cuts = train_cuts.repeat(params.libri_copies) + train_cuts_len = train_cuts_len * params.libri_copies + datasets_and_weights = [ (train_cuts, train_cuts_len), + (gigaspeech_cuts, gigaspeech_cuts_len) ] + cuts, weights = zip(*datasets_and_weights) + train_cuts = CutSet.mux(*cuts, weights=weights) def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds From 75ed4fea68da68fba6e68f18df4abf1818e3ad34 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 2 Dec 2025 13:26:00 +0800 Subject: [PATCH 0754/1191] Fix decode script --- egs/librispeech/ASR/zapformer/decode.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 221f01297b..78fb015840 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -106,7 +106,7 @@ import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import LibriSpeech, AsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -778,7 +778,7 @@ def save_wer_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -1040,17 +1040,18 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) + librispeech = LibriSpeech(args.manifest_dir) + asr_datamodule = AsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() dev_clean_cuts = librispeech.dev_clean_cuts() dev_other_cuts = librispeech.dev_other_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) - dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + dev_clean_dl = asr_datamodule.test_dataloaders(dev_clean_cuts) + dev_other_dl = asr_datamodule.test_dataloaders(dev_other_cuts) test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] From 6d2618b2201199c063ec3afb57feaddb8a68517b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Dec 2025 21:28:47 +0800 Subject: [PATCH 0755/1191] implement rope instead of rel pos emb --- egs/librispeech/ASR/zapformer/train.py | 17 - egs/librispeech/ASR/zipformer/train.py | 18 +- egs/librispeech/ASR/zipformer/zipformer.py | 501 ++++++++++++++------- 3 files changed, 329 insertions(+), 207 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 66f886fc85..2dc10685b1 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -245,20 +245,6 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - parser.add_argument( "--conv-params", type=str, @@ -714,9 +700,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_encoder_layers=lookup(params, "num_encoder_layers"), encoder_dim=lookup(params, "encoder_dim"), query_head_dim=lookup(params, "query_head_dim"), - pos_head_dim=lookup(params, "pos_head_dim"), value_head_dim=lookup(params, "value_head_dim"), - pos_dim=params.pos_dim, num_heads=lookup(params, "num_heads"), feedforward_multiple=lookup(params, "feedforward_multiple"), conv_params=lookup(params, "conv_params"), @@ -1374,7 +1358,6 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - assert params.use_ctc # for now, require CTC, we may remove this requirement later. spec_augment = ExpAugment() diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 367e06da83..7d2c351365 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -235,7 +235,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--query-head-dim", type=str, - default="32", + default="48", help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) @@ -246,20 +246,6 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - parser.add_argument( "--cnn-module-kernel", type=str, @@ -730,9 +716,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_encoder_layers=lookup(params, "num_encoder_layers"), encoder_dim=lookup(params, "encoder_dim"), query_head_dim=lookup(params, "query_head_dim"), - pos_head_dim=lookup(params, "pos_head_dim"), value_head_dim=lookup(params, "value_head_dim"), - pos_dim=params.pos_dim, num_heads=lookup(params, "num_heads"), feedforward_multiple=lookup(params, "feedforward_multiple"), cnn_module_kernel=lookup(params, "cnn_module_kernel"), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3115918969..f12ed03571 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -133,7 +133,6 @@ def _to_tuple(x): self.num_encoder_layers = num_encoder_layers self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) self.num_heads = num_heads = _to_tuple(num_heads) feedforward_multiple = _to_tuple(feedforward_multiple) self.conv_params = conv_params = _to_tuple(conv_params) @@ -157,10 +156,8 @@ def _to_tuple(x): for i in range(num_encoders): encoder_layer = Zipformer2EncoderLayer( embed_dim=encoder_dim[i], - pos_dim=pos_dim, num_heads=num_heads[i], query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], value_head_dim=value_head_dim[i], feedforward_multiple=feedforward_multiple[i], conv_params=conv_params[i], @@ -507,16 +504,13 @@ class Zipformer2EncoderLayer(nn.Module): Examples:: >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) + >>> out = encoder_layer(src) """ def __init__( self, embed_dim: int, - pos_dim: int, num_heads: int, query_head_dim: int, - pos_head_dim: int, value_head_dim: int, feedforward_multiple: int, conv_params: int, @@ -532,12 +526,10 @@ def __init__( self.offset_correlation_limiter = CorrelationLimiter() - self.self_attn_weights = RelPositionMultiheadAttentionWeights( + self.self_attn_weights = MultiheadAttentionWeights( embed_dim, - pos_dim=pos_dim, num_heads=num_heads, query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, ) self.self_attn = SelfAttention(embed_dim, num_heads, value_head_dim) @@ -555,7 +547,6 @@ def __init__( def forward( self, src: Tensor, - pos_emb: Tensor, chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, @@ -565,7 +556,6 @@ def forward( Pass the input through the encoder layer. Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). @@ -585,7 +575,6 @@ def forward( # attn_weights: (num_heads, batch_size, seq_len, seq_len) attn_weights = self.self_attn_weights( src, - pos_emb=pos_emb, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale, @@ -618,7 +607,6 @@ def forward( def streaming_forward( self, src: Tensor, - pos_emb: Tensor, cached_key: Tensor, cached_nonlin_attn: Tensor, cached_val1: Tensor, @@ -631,8 +619,6 @@ def streaming_forward( Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or - (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) cached_key: cached attention key tensor of left context, of shape (left_context_len, batch_size, key_dim) cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape @@ -661,7 +647,6 @@ def streaming_forward( # attn_weights: (num_heads, batch_size, seq_len, seq_len) attn_weights, cached_key = self.self_attn_weights.streaming_forward( src, - pos_emb=pos_emb, cached_key=cached_key, left_context_len=left_context_len, key_padding_mask=src_key_padding_mask, @@ -748,17 +733,6 @@ def __init__( self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) self.proj.lr_scale = 0.75 - # scale up the position weights, this is to fix an issue with the - # linear_pos projections otherwise needing to have too-large scale, larger - # than the "default scale" used in AdamW-like - # log-weight decay in TransformedAdam. The issue we are trying - # to solve is that between different runs, the linear_pos projections of - # different self_attn_weights modules get very different scales.. the - # thinking is that sometimes if one of these linear_pos projections has - # a too-small scale, it never "learns something useful". - self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, length_factor=1.0, - ) self.name = None self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -800,8 +774,6 @@ def forward( (out, out_sd), both of the same shape as src, where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. """ - pos_emb = self.encoder_pos(src) - src_orig_fulldim = src src = self.proj(src) # project to layer dim. @@ -817,7 +789,6 @@ def forward( for i, mod in enumerate(self.layers): src = mod( src, - pos_emb, chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, @@ -861,7 +832,6 @@ def streaming_forward( - output, a Tensor with the same shape as src. - updated states """ - pos_emb = self.encoder_pos(src, left_context_len) num_channels = src.shape[-1] layer_dim = self.layers[0].embed_dim if num_channels > layer_dim: @@ -885,7 +855,6 @@ def streaming_forward( new_cached_conv, ) = mod.streaming_forward( src, - pos_emb, cached_key=cached_key, cached_nonlin_attn=cached_nonlin_attn, cached_val1=cached_val1, @@ -945,119 +914,353 @@ def forward(self, src_orig: Tensor, src: Tensor): -class CompactRelPositionalEncoding(torch.nn.Module): +# taken from torchtune. +class RotaryPositionalEmbeddings(nn.Module): """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the Fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) + This class implements Rotary Positional Embeddings (RoPE) + proposed in https://arxiv.org/abs/2104.09864. + + Reference implementation (used for correctness verfication) + can be found here: + https://github.com/meta-llama/llama/blob/main/llama/model.py#L80 + + In this implementation we cache the embeddings for each position upto + ``max_seq_len`` by computing this during init. + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ``embed_dim // num_heads`` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + """ + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self.rope_init() + + def reset_parameters(self): + self.rope_init() + + def rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward( + self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape + ``[b, s, n_h, h_d]`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + # tensor has shape [b, s, n_h, h_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) + +class MultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with additive relative-position + scores that are kept separate from the regular scores. + + relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. Args: - embed_dim: Embedding dimension. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. + embed_dim: number of channels at the input to this module, e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. """ - def __init__( self, embed_dim: int, - max_len: int = 1000, - length_factor: float = 1.0, + num_heads: int, + query_head_dim: int, + dropout: float = 0.0, ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() + super().__init__() self.embed_dim = embed_dim - assert embed_dim % 2 == 0, embed_dim - self.pe = None - assert length_factor >= 1.0, length_factor - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = self.embed_dim**0.5 - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = ( - compression_length - * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) + self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, + bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 + + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + q = self.copy_query(q) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + q = self.rope(q.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + k = self.rope(k.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 0, 1, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 0, 3, 1) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, 0.1 * aux_loss_scale, + key_padding_mask, self.name) - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001: + self._print_attn_entropy(attn_weights) - self.pe = pe.to(dtype=x.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. + return attn_weights + def streaming_forward( + self, + x: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. + x: input of shape (seq_len, batch_size, embed_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return pos_emb + x = self.in_proj(x) + query_head_dim = self.query_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, 0.1 * aux_loss_scale, + key_padding_mask, self.name) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + class PenalizeLargeAttentionScores(torch.autograd.Function): @@ -1110,7 +1313,6 @@ def backward( - class RelPositionMultiheadAttentionWeights(nn.Module): r"""Module that computes multi-head attention weights with relative position encoding. Various other modules consume the resulting attention weights: see, for example, the @@ -1129,7 +1331,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on any given call to forward(), in training time. """ - def __init__( self, embed_dim: int, @@ -1193,7 +1394,6 @@ def forward( """ x = self.in_proj(x) query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim num_heads = self.num_heads seq_len, batch_size, _ = x.shape @@ -1307,7 +1507,6 @@ def forward( def streaming_forward( self, x: Tensor, - pos_emb: Tensor, cached_key: Tensor, left_context_len: int, key_padding_mask: Tensor, @@ -1315,7 +1514,6 @@ def streaming_forward( r""" Args: x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) cached_key: cached attention key tensor of left context, of shape (left_context_len, batch_size, key_dim) left_context_len: number of left context frames. @@ -1329,7 +1527,6 @@ def streaming_forward( """ x = self.in_proj(x) query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim num_heads = self.num_heads seq_len, batch_size, _ = x.shape @@ -1339,9 +1536,6 @@ def streaming_forward( # self-attention q = x[..., 0:query_dim] k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim # Pad cached left contexts assert cached_key.shape[0] == left_context_len, ( @@ -1356,53 +1550,14 @@ def streaming_forward( k_len = k.shape[0] q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) k = k.reshape(k_len, batch_size, num_heads, query_head_dim) # time1 refers to target, time2 refers to source. q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) attn_scores = torch.matmul(q, k) - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(k_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, k_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - assert attn_scores.shape == ( num_heads, batch_size, @@ -1441,7 +1596,7 @@ def _print_attn_entropy(self, attn_weights: Tensor): class SelfAttention(nn.Module): """ The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + weights, e.g. as computed by MultiheadAttentionWeights. Args: embed_dim: the input and output embedding dimension From 3e01ec8b30b9c8919f98aaca675ed053646c33ae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 3 Dec 2025 22:32:07 +0800 Subject: [PATCH 0756/1191] Introduce multiple == 4 into rope cache so that each base freq is represented times 1,2,3,4. --- egs/librispeech/ASR/zipformer/zipformer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f12ed03571..0e74637a93 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -949,10 +949,14 @@ def reset_parameters(self): self.rope_init() def rope_init(self): - theta = 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) - ) + multiple = 4 # have multiples 1,2,3,4 of each frequency + # theta is inverse angular frequencies + assert self.dim % (2 * multiple) == 0 + D = self.dim // (2 * multiple) + freqs = (self.base // multiple) ** torch.linspace(0., 1., D) + freqs = freqs * torch.arange(1, multiple + 1).unsqueeze(1) + freqs = freqs.flatten() + theta = 1.0 / freqs self.register_buffer("theta", theta, persistent=False) self.build_rope_cache(self.max_seq_len) From 8228bf0b094cfccba17b721e03cdb3c3f7985640 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Dec 2025 14:51:16 +0800 Subject: [PATCH 0757/1191] Version of rope that uses multiples 1 and 3 / 4 of each freq. --- egs/librispeech/ASR/zipformer/zipformer.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0e74637a93..928a7f8f59 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -949,15 +949,18 @@ def reset_parameters(self): self.rope_init() def rope_init(self): - multiple = 4 # have multiples 1,2,3,4 of each frequency - # theta is inverse angular frequencies - assert self.dim % (2 * multiple) == 0 - D = self.dim // (2 * multiple) - freqs = (self.base // multiple) ** torch.linspace(0., 1., D) - freqs = freqs * torch.arange(1, multiple + 1).unsqueeze(1) - freqs = freqs.flatten() - theta = 1.0 / freqs - self.register_buffer("theta", theta, persistent=False) + multiples = [ 1, 4. / 3. ] # for each frequency have f and 4 / 3 * f + assert self.dim % (2 * len(multiples)) == 0 # e.g. self.dim == 64. head dim. + D = self.dim // (2 * len(multiples)) # e.g. D == 16. + + + inv_freqs = (2. ** torch.arange(D)) # [ 1, 2, 4, ... ] + inv_freqs = torch.cat([ m * inv_freqs for m in multiples ], dim=0) + + angular_freqs = math.pi / inv_freqs + # so highest angular_freq is pi, which means flipping between -1 and 1 on alternate tokens. this is + # the nyquist. + self.register_buffer("theta", angular_freqs, persistent=False) self.build_rope_cache(self.max_seq_len) def build_rope_cache(self, max_seq_len: int = 4096) -> None: From 560d4995b58b158420e555bfe1d587b13a696bc0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Dec 2025 14:54:00 +0800 Subject: [PATCH 0758/1191] Code cleanup --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 928a7f8f59..313b4cd231 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -930,18 +930,14 @@ class RotaryPositionalEmbeddings(nn.Module): head in the attention module computed as ``embed_dim // num_heads`` max_seq_len (int): Maximum expected sequence length for the model, if exceeded the cached freqs will be recomputed - base (int): The base for the geometric progression used to compute - the rotation angles """ def __init__( self, dim: int, max_seq_len: int = 4096, - base: int = 10_000, ) -> None: super().__init__() self.dim = dim - self.base = base self.max_seq_len = max_seq_len self.rope_init() @@ -1080,7 +1076,7 @@ def __init__( bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) - self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 + self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096 self.copy_query = Identity() From 00dc3fcb2ea9cff1f9fed4f8afd3dc4beab716ce Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 4 Dec 2025 15:22:21 +0800 Subject: [PATCH 0759/1191] Have four different multiples in rope, and increase query-head-dim from 64 to 128 so largest inv-freq is still pi / 65536. --- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2dc10685b1..85d5472ed4 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -234,7 +234,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--query-head-dim", type=str, - default="64", + default="128", help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 313b4cd231..693d93d0be 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -945,8 +945,9 @@ def reset_parameters(self): self.rope_init() def rope_init(self): - multiples = [ 1, 4. / 3. ] # for each frequency have f and 4 / 3 * f - assert self.dim % (2 * len(multiples)) == 0 # e.g. self.dim == 64. head dim. + # these multiples are on the inverse frequences, so on frequencies the multiples would be the inverses of these. + multiples = [ 1., 4. / 3., 8. / 5., 8. / 7. ] + assert self.dim % (2 * len(multiples)) == 0 # e.g. self.dim == 128. head dim. D = self.dim // (2 * len(multiples)) # e.g. D == 16. From a56a89a069f1e26266383e873b8dfb8c9d005df1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Nov 2025 15:38:46 +0800 Subject: [PATCH 0760/1191] Add gating to SelfAttention module # Conflicts: # egs/librispeech/ASR/zipformer/zipformer.py --- egs/librispeech/ASR/zipformer/zipformer.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 693d93d0be..3e7e5ff632 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -532,7 +532,7 @@ def __init__( query_head_dim=query_head_dim, ) - self.self_attn = SelfAttention(embed_dim, num_heads, value_head_dim) + self.self_attn = GatedSelfAttention(embed_dim, num_heads, value_head_dim) feedforward_dim = embed_dim * feedforward_multiple self.feed_forward1 = FeedforwardModule(embed_dim, feedforward_dim) @@ -1597,9 +1597,9 @@ def _print_attn_entropy(self, attn_weights: Tensor): ) -class SelfAttention(nn.Module): +class GatedSelfAttention(nn.Module): """ - The simplest possible attention module. This one works with already-computed attention + Self-attention module with sigmoid gating. This one works with already-computed attention weights, e.g. as computed by MultiheadAttentionWeights. Args: @@ -1607,7 +1607,6 @@ class SelfAttention(nn.Module): num_heads: the number of attention heads value_head_dim: the value dimension per head """ - def __init__( self, embed_dim: int, @@ -1615,8 +1614,10 @@ def __init__( value_head_dim: int, ) -> None: super().__init__() - self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, - bias=True, out_groups=num_heads) + self.in_proj = ScaledLinear(embed_dim, 2 *num_heads * value_head_dim, + bias=True) + + self.sigmoid = nn.Sigmoid() self.out_proj = ScaledLinear( num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 @@ -1648,7 +1649,8 @@ def forward( num_heads = attn_weights.shape[0] assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = self.in_proj(x) # (seq_len, batch_size, 2 * num_heads * value_head_dim) + x, s = x.chunk(2, dim=-1) x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) # now x: (num_heads, batch_size, seq_len, value_head_dim) value_head_dim = x.shape[-1] @@ -1664,6 +1666,7 @@ def forward( ) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = x * self.sigmoid(s) x = self.out_proj(x) return x From d2151985614364f8867918b7a4ca409317c10860 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Dec 2025 12:58:19 +0800 Subject: [PATCH 0761/1191] Double frequency resolution of conv module without changing conv_params. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3e7e5ff632..598e780f2a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1775,8 +1775,8 @@ def __init__(self, bias: bool = True): super().__init__() self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) - # the factor of 2 is for (sin, cos) - self.weight_proj = nn.Linear(params_per_channel, 2 * params_per_channel) + # one factor of 2 is for (sin, cos); the other is to double the num representable freqs + self.weight_proj = nn.Linear(params_per_channel, 4 * params_per_channel) if bias: self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) From a7d4fff414e7dafd6772c66ebdb5dc16bdb01d16 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Dec 2025 23:21:56 +0800 Subject: [PATCH 0762/1191] Documentation changes. --- egs/librispeech/ASR/zipformer/zipformer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 598e780f2a..66413da0d8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -101,7 +101,7 @@ def __init__( downsampling_factor: Tuple[int] = (2, 4), encoder_dim: Union[int, Tuple[int]] = 384, num_encoder_layers: Union[int, Tuple[int]] = 4, - query_head_dim: Union[int, Tuple[int]] = 24, + query_head_dim: Union[int, Tuple[int]] = 64, pos_head_dim: Union[int, Tuple[int]] = 4, value_head_dim: Union[int, Tuple[int]] = 12, num_heads: Union[int, Tuple[int]] = 8, @@ -937,6 +937,7 @@ def __init__( max_seq_len: int = 4096, ) -> None: super().__init__() + assert dim in [64, 128] self.dim = dim self.max_seq_len = max_seq_len self.rope_init() @@ -946,11 +947,15 @@ def reset_parameters(self): def rope_init(self): # these multiples are on the inverse frequences, so on frequencies the multiples would be the inverses of these. - multiples = [ 1., 4. / 3., 8. / 5., 8. / 7. ] + # it's the frequencies we want to be exact multiples of each other. + if self.dim == 64: + multiples = [ 1., 4. / 3. ] + else: + assert self.dim == 128 + multiples = [ 1., 4. / 3., 8. / 5., 8. / 7. ] assert self.dim % (2 * len(multiples)) == 0 # e.g. self.dim == 128. head dim. D = self.dim // (2 * len(multiples)) # e.g. D == 16. - inv_freqs = (2. ** torch.arange(D)) # [ 1, 2, 4, ... ] inv_freqs = torch.cat([ m * inv_freqs for m in multiples ], dim=0) From 5c4dabb7dbdb19d6c5f741065560723935ad7484 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 10 Dec 2025 10:18:31 +0800 Subject: [PATCH 0763/1191] Change conv_params from 32 to 32,32,16,32 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 85d5472ed4..463b72f21d 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -248,7 +248,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--conv-params", type=str, - default="32", + default="32,32,16,32", help="Parameters per channel of convolution kernels", ) From e5d113d193d8ec36884bb2adb34f370f058a81c9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 16 Dec 2025 19:42:20 +0800 Subject: [PATCH 0764/1191] Take decoding script updates from 1766. --- egs/librispeech/ASR/zapformer/ctc_decode.py | 1277 ++++++++++++++++++- egs/librispeech/ASR/zapformer/decode.py | 117 +- 2 files changed, 1377 insertions(+), 17 deletions(-) mode change 120000 => 100755 egs/librispeech/ASR/zapformer/ctc_decode.py diff --git a/egs/librispeech/ASR/zapformer/ctc_decode.py b/egs/librispeech/ASR/zapformer/ctc_decode.py deleted file mode 120000 index a78e5c1df0..0000000000 --- a/egs/librispeech/ASR/zapformer/ctc_decode.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/ctc_decode.py b/egs/librispeech/ASR/zapformer/ctc_decode.py new file mode 100755 index 0000000000..f3bce1b43d --- /dev/null +++ b/egs/librispeech/ASR/zapformer/ctc_decode.py @@ -0,0 +1,1276 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) ctc-greedy-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(3) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(4) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(5) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(6) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring + +(7) attention-decoder-rescoring-no-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(8) attention-decoder-rescoring-with-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeech, GigaSpeech, AsrDataModule +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.context_graph import ContextGraph, ContextState +from icefall.decode import ( + ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_attention_decoder_with_ngram, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.lm_wrapper import LmScorer +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: # only used for gigaspeech + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--giga", + type=str2bool, + default=False, + help="If True, decode gigaspeech in addition to librispeech test sets." + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (3) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (4) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (5) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (6) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM + rescored lattice, rescore them with the attention decoder. + - (10) ctc-prefix-beam-search. Extract n paths with the given beam, the best + path of the n paths is the decoding result. + - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with + the given beam, rescore them with the attention decoder. + - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fussion during + beam search, LODR and hotwords are also supported in this decoding method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--nnlm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--nnlm-scale", + type=float, + default=0, + help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion. + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.6, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--lodr-ngram", + type=str, + help="The path to the lodr ngram", + ) + + parser.add_argument( + "--lodr-lm-scale", + type=float, + default=0, + help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.", + ) + + parser.add_argument( + "--context-score", + type=float, + default=0, + help=""" + The bonus score of each token for the context biasing words/phrases. + 0 means don't use contextual biasing. + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "beam": 4, # for prefix-beam-search + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + device = params.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2] + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + if params.decoding_method == "ctc-greedy-search": + hyps = ctc_greedy_search(ctc_output, encoder_out_lens) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(hyps) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-greedy-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search": + token_ids = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, token_ids in best_path_dict.items(): + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + token_ids = ctc_prefix_beam_search_shallow_fussion( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + NNLM=NNLM, + LODR_lm=LODR_lm, + LODR_lm_scale=params.lodr_lm_scale, + context_graph=context_graph, + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search-shallow-fussion" + return {key: hyps} + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} # note: returns words + + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no-rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} # note: returns BPE tokens + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "attention-decoder-rescoring-with-ngram": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + best_path_dict = rescore_with_attention_decoder_with_ngram( + lattice=rescored_lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + if params.giga: + results = post_processing(results) + + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.decoding_method in ( + "attention-decoder-rescoring-with-ngram", + "whole-lattice-rescoring", + ): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + + test_set_wers = dict() + for key, results in results_dict.items(): + if params.giga: + results = post_processing(results) + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "ctc-decoding", + "ctc-greedy-search", + "ctc-prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + "attention-decoder-rescoring-no-ngram", + "attention-decoder-rescoring-with-ngram", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + if params.nnlm_scale != 0: + params.suffix += f"_nnlm-scale-{params.nnlm_scale}" + if params.lodr_lm_scale != 0: + params.suffix += f"_lodr-scale-{params.lodr_lm_scale}" + if params.context_score != 0: + params.suffix += f"_context_score-{params.context_score}" + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + params.device = device + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + params.eos_id = 1 + params.sos_id = 1 + + if params.decoding_method in [ + "ctc-decoding", + "ctc-greedy-search", + "ctc-prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", + "attention-decoder-rescoring-no-ngram", + ]: + HLG = None + H = None + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method in [ + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + # only load the neural network LM if required + NNLM = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.nnlm_scale != 0 + ): + NNLM = LmScorer( + lm_type=params.nnlm_type, + params=params, + device=device, + lm_scale=params.nnlm_scale, + ) + NNLM.to(device) + NNLM.eval() + + LODR_lm = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.lodr_lm_scale != 0 + ): + assert os.path.exists( + params.lodr_ngram + ), f"LODR ngram does not exists, given path : {params.lodr_ngram}" + logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}") + LODR_lm = NgramLm( + params.lodr_ngram, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {LODR_lm.lm.num_states}") + + context_graph = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.context_score != 0 + ): + assert os.path.exists( + params.context_file + ), f"context_file does not exists, given path : {params.context_file}" + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(bpe_model.encode(line.strip())) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + + asr_datamodule = AsrDataModule(args) + test_sets = [] + test_dl = [] + if True: + librispeech = LibriSpeech(args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + dev_clean_dl = asr_datamodule.test_dataloaders(dev_clean_cuts) + dev_other_dl = asr_datamodule.test_dataloaders(dev_other_cuts) + + test_sets += ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl += [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + if args.giga: + gigaspeech = GigaSpeech(args.manifest_dir) + test_cuts = gigaspeech.test_cuts() + dev_cuts = gigaspeech.dev_cuts() + giga_test_dl = asr_datamodule.test_dataloaders(test_cuts) + giga_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) + test_sets += ["dev", "test"] + test_dl += [giga_test_dl, giga_dev_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 78fb015840..841ff0142b 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -106,7 +106,7 @@ import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeech, AsrDataModule +from asr_datamodule import LibriSpeech, GigaSpeech, AsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -142,6 +142,66 @@ LOG_EPS = math.log(1e-10) +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: # only used for gigaspeech + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + def get_parser(): parser = argparse.ArgumentParser( @@ -378,6 +438,13 @@ def get_parser(): help="""Skip scoring, but still save the ASR output (for eval sets).""", ) + parser.add_argument( + "--giga", + type=str2bool, + default=False, + help="""If True, decode gigaspeech in addition to librispeech test sets.""", + ) + add_model_arguments(parser) return parser @@ -716,6 +783,7 @@ def decode_dataset( batch_str = f"{batch_idx}/{num_batches}" logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results @@ -730,8 +798,10 @@ def save_asr_output( for key, results in results_dict.items(): recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) + if params.giga: + results = post_processing(results) + store_transcripts(filename=recogs_filename, texts=results) logging.info(f"The transcripts are stored in {recogs_filename}") @@ -747,6 +817,9 @@ def save_wer_results( """ test_set_wers = dict() for key, results in results_dict.items(): + if params.giga: + results = post_processing(results) + # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" @@ -1040,21 +1113,33 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - librispeech = LibriSpeech(args.manifest_dir) asr_datamodule = AsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - dev_clean_cuts = librispeech.dev_clean_cuts() - dev_other_cuts = librispeech.dev_other_cuts() - - test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) - test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) - dev_clean_dl = asr_datamodule.test_dataloaders(dev_clean_cuts) - dev_other_dl = asr_datamodule.test_dataloaders(dev_other_cuts) - - test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] - test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + test_sets = [] + test_dl = [] + if True: # if not args.giga: + librispeech = LibriSpeech(args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + dev_clean_dl = asr_datamodule.test_dataloaders(dev_clean_cuts) + dev_other_dl = asr_datamodule.test_dataloaders(dev_other_cuts) + + test_sets += ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl += [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + if args.giga: + gigaspeech = GigaSpeech(args.manifest_dir) + test_cuts = gigaspeech.test_cuts() + dev_cuts = gigaspeech.dev_cuts() + giga_test_dl = asr_datamodule.test_dataloaders(test_cuts) + giga_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) + test_sets += ["dev", "test"] + test_dl += [giga_test_dl, giga_dev_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( From 1436c4ab61bf9ecf2674e049a40121fa42641637 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 13:12:41 +0800 Subject: [PATCH 0765/1191] Disable time warping --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index f6735a65c4..2bd2077caf 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -435,7 +435,7 @@ def forward( B = batch_size // num_copies x = x.reshape(num_copies, B, seq_len, num_channels) - do_time_warp = True + do_time_warp = False if do_time_warp: # Apply time warping. First append the copies on the channel # dimension so all copies get the exact same time-warping. From 78eeebdc074f371630e4a44ae2b0c9a9f47c6405 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 15:06:01 +0800 Subject: [PATCH 0766/1191] Code cleanups in zipformer.py --- egs/librispeech/ASR/zipformer/zipformer.py | 304 +-------------------- 1 file changed, 7 insertions(+), 297 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 66413da0d8..a90c2ee742 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -72,17 +72,12 @@ class Zipformer2(EncoderInterface): num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack query_head_dim (int or Tuple[int]): dimension of query and key per attention head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head value_head_dim (int or Tuple[int]): dimension of value in each attention head num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. Must be at least 4. feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules conv_params (int or Tuple[int])): Kernel size of convolution module - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - causal (bool): if True, support chunkwise causal convolution. This should not hurt WER as no modeling power is lost, but the convolution modules will be slightly slower and use more memory. Enables use of the chunk_size and @@ -102,12 +97,10 @@ def __init__( encoder_dim: Union[int, Tuple[int]] = 384, num_encoder_layers: Union[int, Tuple[int]] = 4, query_head_dim: Union[int, Tuple[int]] = 64, - pos_head_dim: Union[int, Tuple[int]] = 4, value_head_dim: Union[int, Tuple[int]] = 12, num_heads: Union[int, Tuple[int]] = 8, feedforward_multiple: Union[int, Tuple[int]] = 4, conv_params: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, causal: bool = False, chunk_size: Tuple[int] = [-1], left_context_frames: Tuple[int] = [-1], @@ -170,7 +163,6 @@ def _to_tuple(x): encoder_layer, num_encoder_layers[i], dim=downsampling_factor[i]*input_dim, - pos_dim=pos_dim, ) encoders.append(encoder) @@ -271,20 +263,14 @@ def forward( ) x = upsample_by(x, ds) - - assert self.output_downsampling_factor == 2, self.output_downsampling_factor od = self.output_downsampling_factor x = downsample_by(x, od) x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 + if od > 1: + x_lens = (x_lens + od - 1) // od - return x, lengths + return x, x_lens def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -456,6 +442,8 @@ def pad_mask(mask: Optional[Tensor], seq_len: int): def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: # x: (seq_len, batch_size, num_channels) # Returns: (seq_len // downsampling_factor, batch_size, num_channels * downsampling_factor) + if downsampling_factor == 1: + return x (seq_len, batch_size, num_channels) = x.shape x = x.reshape(seq_len // downsampling_factor, downsampling_factor, batch_size, num_channels) x = x.permute(0, 2, 1, 3) @@ -465,6 +453,8 @@ def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: # x: (seq_len, batch_size, num_channels) # Returns: (seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + if upsampling_factor == 1: + return x (seq_len, batch_size, num_channels) = x.shape x = x.reshape(seq_len, batch_size, upsampling_factor, num_channels // upsampling_factor) x = x.permute(0, 2, 1, 3) @@ -712,7 +702,6 @@ class Zipformer2Encoder(nn.Module): encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). dim: the dimension of the input and output (layer dim may be less than this). - pos_dim: the dimension for the relative positional encoding Examples:: >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) @@ -725,7 +714,6 @@ def __init__( encoder_layer: nn.Module, num_layers: int, dim: int, - pos_dim: int, ) -> None: super().__init__() @@ -1322,284 +1310,6 @@ def backward( -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - assert pos_head_dim <= query_head_dim - self.pos_head_dim = pos_head_dim - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, - bias=True, initial_scale=0.125 * query_head_dim**-0.25 - ) - - - self.linear_pos = ScaledLinear( - pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnostics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - self.copy_key = Identity() - - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - aux_loss_scale: float = 0.0, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.copy_key(k) - - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = q[..., :pos_head_dim] - p = self.copy_pos_query(p) # diagnostics only - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - # attn_scores: (head, batch, query_time, key_time) - - if True: - # position scores. - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2) - - if self.training: - pe = pos_emb.expand(num_heads, batch_size, pos_head_dim, seq_len2) - pe = pe.reshape(num_heads * batch_size, pos_head_dim, seq_len2).permute(0, 2, 1) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, seq_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - - attn_scores = attn_scores + pos_scores - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - seq_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. - attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, 0.1 * aux_loss_scale, - key_padding_mask, self.name) - - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001: - self._print_attn_entropy(attn_weights) - - return attn_weights - - def streaming_forward( - self, - x: Tensor, - cached_key: Tensor, - left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - - Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, ( - cached_key.shape[0], - left_context_len, - ) - k = torch.cat([cached_key, k], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - assert attn_scores.shape == ( - num_heads, - batch_size, - seq_len, - k_len, - ), attn_scores.shape - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - attn_weights = attn_scores.softmax(dim=-1) - - return attn_weights, cached_key - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.info( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) class GatedSelfAttention(nn.Module): From d9760496534539bf1916292b130dac1ca3098595 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 15:08:30 +0800 Subject: [PATCH 0767/1191] normalize rms of embeddings in correlation_loss; divide its aux_loss_scale by 10. --- egs/librispeech/ASR/zipformer/scaling.py | 9 ++++++++- egs/librispeech/ASR/zipformer/zipformer.py | 3 +-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8031f89e20..d7e546452c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1614,9 +1614,16 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 x = x * mask x, y = x.to(torch.float), y.to(torch.float) x, y = x.detach(), y.detach() - x_orig, y_orig = x, y x.requires_grad = True y.requires_grad = True + x_orig, y_orig = x, y + + def norm(x: Tensor): + eps = 1.0e-20 + return x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt() + + x = norm(x) + y = norm(y) half_batch = batch_size // 2 if half_batch <= 1: # the reason we also return None if half_batch==1 is because of CR-CTC diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a90c2ee742..5e4f0dc370 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -515,7 +515,6 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) self.offset_correlation_limiter = CorrelationLimiter() - self.self_attn_weights = MultiheadAttentionWeights( embed_dim, num_heads=num_heads, @@ -586,7 +585,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), - aux_loss_scale, mask=src_key_padding_mask)) + 0.1*aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From 83e5b5ec3010d5e5da025e84d25f8d96acfcb572 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 15:43:00 +0800 Subject: [PATCH 0768/1191] Put small scale on correlation if it is large. --- egs/librispeech/ASR/zipformer/scaling.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d7e546452c..b616336912 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1643,12 +1643,17 @@ def norm(x: Tensor): # correlation between tr(M) estimates between elements of the batch. correlation = r[0::2] * r[1::2] - if random.random() < 0.001: + correlation = correlation.mean() + + if random.random() < 0.0001: logging.info( - f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation.mean()}" + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}" ) - correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale / num_channels)) + correlation = correlation.clamp(min=-1., max=1.) + (0.1 * correlation.clamp(min=-10., max=10.)) + (0.01 * correlation.clamp(min=-100., max=100.)) + + correlation.backward(gradient=torch.tensor(aux_loss_scale * half_batch * seq_len, device=correlation.device)) + return x_orig.grad, y_orig.grad, None, None, None From 1ddc66bc5783bd0dff990a488a3054043315b0bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 17:20:51 +0800 Subject: [PATCH 0769/1191] Penalize correlation even if over 100 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b616336912..d673b214cc 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1650,7 +1650,7 @@ def norm(x: Tensor): f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}" ) - correlation = correlation.clamp(min=-1., max=1.) + (0.1 * correlation.clamp(min=-10., max=10.)) + (0.01 * correlation.clamp(min=-100., max=100.)) + correlation = correlation.clamp(min=-1., max=1.) + (0.1 * correlation.clamp(min=-10., max=10.)) + (0.01 * correlation) correlation.backward(gradient=torch.tensor(aux_loss_scale * half_batch * seq_len, device=correlation.device)) From 0d69cf24c4ea88cfdf59ca8e09d60df4a5ca4cb6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 17:42:02 +0800 Subject: [PATCH 0770/1191] different way to deal with large correlations, and divide by sqrt dim --- egs/librispeech/ASR/zipformer/scaling.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d673b214cc..e2798f66c7 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1638,19 +1638,27 @@ def norm(x: Tensor): # r: (batch_size, M, dim) r = torch.matmul(x, r.transpose(1, 2)) # (batch_size, seq_len, m) r = torch.matmul(r.transpose(1, 2), y) # (batch_size, m, dim) - r = r * (1. / seq_len) + r = r * 1. / (seq_len * (num_channels ** 0.5)) + # the summed-over dims in matmuls were num_channels and seq_len but the channel dims can be + # treated as independent not correlated so power of 0.5 # correlation between tr(M) estimates between elements of the batch. correlation = r[0::2] * r[1::2] correlation = correlation.mean() + + + corr_plus = (1 + correlation.abs()) + scale = (1 + corr_plus.log()) / corr_plus + if random.random() < 0.0001: logging.info( - f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}" + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, scaled_correlation={correlation*scale}" ) - correlation = correlation.clamp(min=-1., max=1.) + (0.1 * correlation.clamp(min=-10., max=10.)) + (0.01 * correlation) + + correlation = correlation * scale correlation.backward(gradient=torch.tensor(aux_loss_scale * half_batch * seq_len, device=correlation.device)) From 678785578f25454545040cea80a15c03e6711ca9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 17:59:43 +0800 Subject: [PATCH 0771/1191] Remove factor of 0.1 in aux_loss_scale. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 5e4f0dc370..20fc556d1c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -585,7 +585,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), - 0.1*aux_loss_scale, mask=src_key_padding_mask)) + aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From 6dbc12992396421ebfa535ff8c239a48022bb019 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 18:59:07 +0800 Subject: [PATCH 0772/1191] Revert removal of factor of 0.1 in correlation limiter. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 20fc556d1c..7172f30f65 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -585,7 +585,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), - aux_loss_scale, mask=src_key_padding_mask)) + 0.1 * aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From 416ca47c07b988b21404717c096329c60b725542 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 20:17:50 +0800 Subject: [PATCH 0773/1191] Fix bug regarding mask. --- egs/librispeech/ASR/zipformer/scaling.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e2798f66c7..11e4d17ec5 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1609,9 +1609,6 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 (batch_size, seq_len, num_channels) = x.shape with torch.enable_grad(): with torch.amp.autocast('cuda', enabled=False): - if mask is not None: - mask = (~mask).to(x.dtype).unsqueeze(-1) - x = x * mask x, y = x.to(torch.float), y.to(torch.float) x, y = x.detach(), y.detach() x.requires_grad = True @@ -1621,9 +1618,14 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 def norm(x: Tensor): eps = 1.0e-20 return x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt() - x = norm(x) y = norm(y) + + if mask is not None: + mask = (~mask).to(x.dtype).unsqueeze(-1) + x = x * mask + y = y * mask + half_batch = batch_size // 2 if half_batch <= 1: # the reason we also return None if half_batch==1 is because of CR-CTC From 7fa3a3b1fdffd83b7ce9635cd60d980810c7526c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 20:29:22 +0800 Subject: [PATCH 0774/1191] Swap src_orig, offset in correlation_limiter --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7172f30f65..1112669ce2 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -584,7 +584,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( - offset.permute(1, 0, 2), src_orig.permute(1, 0, 2), + src_orig.permute(1, 0, 2), offset.permute(1, 0, 2), 0.1 * aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From 4e8c397b3f6dce72bea1cd8b3d4d6a4a3c5797cb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 20:43:52 +0800 Subject: [PATCH 0775/1191] Remove factor of 0.1 on aux_loss_scale --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 1112669ce2..6d8657fb26 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -585,7 +585,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( src_orig.permute(1, 0, 2), offset.permute(1, 0, 2), - 0.1 * aux_loss_scale, mask=src_key_padding_mask)) + aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From 858621ea48fb2edc649040aa6f4da3a6f683aad2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 23:37:48 +0800 Subject: [PATCH 0776/1191] Simplify CorrelationLimiter slightly --- egs/librispeech/ASR/zipformer/scaling.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 11e4d17ec5..3d0d66f6f8 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1647,22 +1647,12 @@ def norm(x: Tensor): # correlation between tr(M) estimates between elements of the batch. correlation = r[0::2] * r[1::2] - correlation = correlation.mean() - - - - corr_plus = (1 + correlation.abs()) - scale = (1 + corr_plus.log()) / corr_plus - if random.random() < 0.0001: logging.info( - f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, scaled_correlation={correlation*scale}" + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation.mean()}" ) - - correlation = correlation * scale - - correlation.backward(gradient=torch.tensor(aux_loss_scale * half_batch * seq_len, device=correlation.device)) + correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale)) return x_orig.grad, y_orig.grad, None, None, None From 9e42487ec198da0053a5e21762d62dead7d8991d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 19 Dec 2025 23:58:00 +0800 Subject: [PATCH 0777/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3d0d66f6f8..0af42cef87 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1652,7 +1652,7 @@ def norm(x: Tensor): f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation.mean()}" ) - correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale)) + correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale / correlation.numel())) return x_orig.grad, y_orig.grad, None, None, None From 6d006bd58db1bc6794db5c15d87c9cf4fb9a6e13 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 20 Dec 2025 00:24:45 +0800 Subject: [PATCH 0778/1191] Bug fix and multiply by two --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 0af42cef87..663135c023 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1652,7 +1652,7 @@ def norm(x: Tensor): f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation.mean()}" ) - correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale / correlation.numel())) + correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale * batch_size * seq_len / correlation.numel())) return x_orig.grad, y_orig.grad, None, None, None From 7f4dd0956bad42c55868e473d6eab84c8c1ce531 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 20 Dec 2025 13:53:06 +0800 Subject: [PATCH 0779/1191] Double aux_loss_scale of correlation_limiter --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6d8657fb26..4a847641c7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -585,7 +585,7 @@ def forward( offset = with_loss(offset, self.offset_correlation_limiter( src_orig.permute(1, 0, 2), offset.permute(1, 0, 2), - aux_loss_scale, mask=src_key_padding_mask)) + 2. * aux_loss_scale, mask=src_key_padding_mask)) src = src_orig + offset From 89097d6fbe77b0cd3763ea90bf62e4d3c3d0efa9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 21 Dec 2025 14:08:51 +0800 Subject: [PATCH 0780/1191] Make time warp not shared between copies. --- egs/librispeech/ASR/zapformer/model.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index f6735a65c4..b35155becb 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -437,9 +437,13 @@ def forward( do_time_warp = True if do_time_warp: - # Apply time warping. First append the copies on the channel - # dimension so all copies get the exact same time-warping. - x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) + shared_time_warp = False + if shared_time_warp: + # Apply time warping. First append the copies on the channel + # dimension so all copies get the exact same time-warping. + x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) + else: + x = x.reshape(num_copies * B, seq_len, num_channels) assert supervision_segments is not None with torch.amp.autocast('cuda', enabled=False): @@ -448,8 +452,12 @@ def forward( time_warp_factor=time_warp_factor, supervision_segments=supervision_segments[:B], ) - x = x.reshape(B, seq_len, num_copies, num_channels) - x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + if shared_time_warp: + x = x.reshape(B, seq_len, num_copies, num_channels) + x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) + else: + x = x.reshape(num_copies, B, seq_len, num_channels) + # x_no_specaug is several repeats of the 1st copy of the data, which # is the one not augmented with Musan. But it does have time From 7d8341be742748de521e3a1a7faa13984091629f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 21 Dec 2025 18:16:12 +0800 Subject: [PATCH 0781/1191] Bug fix --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index b35155becb..cc96669bc4 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -450,7 +450,7 @@ def forward( x = time_warp( x.to(torch.float), time_warp_factor=time_warp_factor, - supervision_segments=supervision_segments[:B], + supervision_segments=supervision_segments[:x.shape[0]], ) if shared_time_warp: x = x.reshape(B, seq_len, num_copies, num_channels) From d9760ae33fb0c7af18ea7094b1b8de469046959d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Dec 2025 11:27:58 +0800 Subject: [PATCH 0782/1191] Increase joiner_multiple from 8 to 12. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 463b72f21d..061f95aa81 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--joiner-multiple", type=int, - default=8, + default=12, help="""Dimension used in the joiner model. Outputs from the encoder and decoder model are projected to this dimension before adding. From 17782cdf6ff8c7221273f9530bd07f1c107af479 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Dec 2025 11:40:19 +0800 Subject: [PATCH 0783/1191] Revert rope frequencies to normal rope. --- egs/librispeech/ASR/zipformer/zipformer.py | 26 +++++++--------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 66413da0d8..18ed3f5009 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -935,10 +935,12 @@ def __init__( self, dim: int, max_seq_len: int = 4096, + base: int = 10_000, ) -> None: super().__init__() assert dim in [64, 128] self.dim = dim + self.base = base self.max_seq_len = max_seq_len self.rope_init() @@ -946,23 +948,11 @@ def reset_parameters(self): self.rope_init() def rope_init(self): - # these multiples are on the inverse frequences, so on frequencies the multiples would be the inverses of these. - # it's the frequencies we want to be exact multiples of each other. - if self.dim == 64: - multiples = [ 1., 4. / 3. ] - else: - assert self.dim == 128 - multiples = [ 1., 4. / 3., 8. / 5., 8. / 7. ] - assert self.dim % (2 * len(multiples)) == 0 # e.g. self.dim == 128. head dim. - D = self.dim // (2 * len(multiples)) # e.g. D == 16. - - inv_freqs = (2. ** torch.arange(D)) # [ 1, 2, 4, ... ] - inv_freqs = torch.cat([ m * inv_freqs for m in multiples ], dim=0) - - angular_freqs = math.pi / inv_freqs - # so highest angular_freq is pi, which means flipping between -1 and 1 on alternate tokens. this is - # the nyquist. - self.register_buffer("theta", angular_freqs, persistent=False) + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) self.build_rope_cache(self.max_seq_len) def build_rope_cache(self, max_seq_len: int = 4096) -> None: @@ -1082,7 +1072,7 @@ def __init__( bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) - self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096 + self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 self.copy_query = Identity() From f320997975b402353bd72973787606674687ce88 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Dec 2025 12:50:09 +0800 Subject: [PATCH 0784/1191] Revert joiner-multiple to 8. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 061f95aa81..463b72f21d 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--joiner-multiple", type=int, - default=12, + default=8, help="""Dimension used in the joiner model. Outputs from the encoder and decoder model are projected to this dimension before adding. From bc20167acc2f55e7829546fd98b4b1e0636fec05 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Dec 2025 13:09:18 +0800 Subject: [PATCH 0785/1191] Set do_time_warp to True --- egs/librispeech/ASR/zapformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index f0927d5cb1..cc96669bc4 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -435,7 +435,7 @@ def forward( B = batch_size // num_copies x = x.reshape(num_copies, B, seq_len, num_channels) - do_time_warp = False + do_time_warp = True if do_time_warp: shared_time_warp = False if shared_time_warp: From 781bccb08ac5f58ff003c34f73b688109be4fb02 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 22 Dec 2025 11:27:58 +0800 Subject: [PATCH 0786/1191] Increase joiner_multiple from 8 to 12. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 463b72f21d..061f95aa81 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -262,7 +262,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--joiner-multiple", type=int, - default=8, + default=12, help="""Dimension used in the joiner model. Outputs from the encoder and decoder model are projected to this dimension before adding. From 874c5a99cf7022171eb765674b12d23a971b9331 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 24 Dec 2025 21:10:34 +0800 Subject: [PATCH 0787/1191] Do not do random projection in CorrelationLimiter --- egs/librispeech/ASR/zipformer/scaling.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 663135c023..8b400f9fe5 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1603,7 +1603,6 @@ def forward(ctx, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Ten @staticmethod def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 x, y = ctx.saved_tensors - dim = x.shape[-1] mask = ctx.mask aux_loss_scale = ctx.aux_loss_scale (batch_size, seq_len, num_channels) = x.shape @@ -1635,17 +1634,15 @@ def norm(x: Tensor): x = x[:2*half_batch] y = y[:2*half_batch] - M = 64 # number of random vectors, this should be more than enough. - r = torch.randn(half_batch, M, dim, device=x.device).repeat_interleave(2, dim=0) - # r: (batch_size, M, dim) - r = torch.matmul(x, r.transpose(1, 2)) # (batch_size, seq_len, m) - r = torch.matmul(r.transpose(1, 2), y) # (batch_size, m, dim) - r = r * 1. / (seq_len * (num_channels ** 0.5)) - # the summed-over dims in matmuls were num_channels and seq_len but the channel dims can be - # treated as independent not correlated so power of 0.5 - - # correlation between tr(M) estimates between elements of the batch. - correlation = r[0::2] * r[1::2] + + x1, x2 = x[0::2], x[1::2] + y1, y2 = x[0::2], x[1::2] + + S1 = torch.matmul(x1.reshape(-1, num_channels).t(), y1.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) + S2 = torch.matmul(x2.reshape(-1, num_channels).t(), y2.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) + + # S1, S2: (num_channels, num_channels) + correlation = S1 * S2 if random.random() < 0.0001: logging.info( From de6162116642303858583dca0d7708a12a16f444 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 24 Dec 2025 22:42:09 +0800 Subject: [PATCH 0788/1191] Add limit=0.03 in CorrelationLimiter. --- egs/librispeech/ASR/zipformer/scaling.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8b400f9fe5..d01d52c0c5 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1593,9 +1593,10 @@ def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: class CorrelationLimiterFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Tensor], name: str): + def forward(ctx, x: Tensor, y: Tensor, aux_loss_scale: float, limit: float, mask: Optional[Tensor], name: str): ctx.save_for_backward(x, y) ctx.mask = mask + ctx.limit = limit ctx.aux_loss_scale = aux_loss_scale ctx.name = name return x @@ -1642,14 +1643,15 @@ def norm(x: Tensor): S2 = torch.matmul(x2.reshape(-1, num_channels).t(), y2.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) # S1, S2: (num_channels, num_channels) - correlation = S1 * S2 + correlation = (S1 * S2).mean() + loss = (correlation - ctx.limit).relu() if random.random() < 0.0001: logging.info( - f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation.mean()}" + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, loss={loss}" ) - correlation.backward(gradient=torch.full_like(correlation, aux_loss_scale * batch_size * seq_len / correlation.numel())) + loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device)) return x_orig.grad, y_orig.grad, None, None, None @@ -1660,9 +1662,10 @@ class CorrelationLimiter(torch.nn.Module): Adds a penalty in backprop if feature x and feature y are correlated. Assumes input is (batch, seq, channel) """ - def __init__(self): + def __init__(self, limit: FloatLike = 0.03): super().__init__() self.name = None + self.limit = limit def forward(self, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: @@ -1674,7 +1677,7 @@ def forward(self, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Te return torch.tensor(0.0, device=x.device) else: return CorrelationLimiterFunction.apply(x, y, - aux_loss_scale, mask, + aux_loss_scale, limit, mask, self.name) From dabdc0b85070032c7a191d55b9d0ebaeaff56748 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 24 Dec 2025 22:42:49 +0800 Subject: [PATCH 0789/1191] Add limit=0.03 in CorrelationLimiter. --- egs/librispeech/ASR/zipformer/scaling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d01d52c0c5..2e22b4f462 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1654,7 +1654,7 @@ def norm(x: Tensor): loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device)) - return x_orig.grad, y_orig.grad, None, None, None + return x_orig.grad, y_orig.grad, None, None, None, None class CorrelationLimiter(torch.nn.Module): @@ -1677,7 +1677,9 @@ def forward(self, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Te return torch.tensor(0.0, device=x.device) else: return CorrelationLimiterFunction.apply(x, y, - aux_loss_scale, limit, mask, + aux_loss_scale, + float(self.limit), + mask, self.name) From ba8a31af7470244cc676801163fcb6b2515735e9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 24 Dec 2025 23:31:34 +0800 Subject: [PATCH 0790/1191] Fix bug that used y instead of x; restore limit to 0.0. --- egs/librispeech/ASR/zipformer/scaling.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2e22b4f462..aa101fb51b 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,9 +1635,8 @@ def norm(x: Tensor): x = x[:2*half_batch] y = y[:2*half_batch] - x1, x2 = x[0::2], x[1::2] - y1, y2 = x[0::2], x[1::2] + y1, y2 = y[0::2], y[1::2] S1 = torch.matmul(x1.reshape(-1, num_channels).t(), y1.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) S2 = torch.matmul(x2.reshape(-1, num_channels).t(), y2.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) @@ -1662,7 +1661,7 @@ class CorrelationLimiter(torch.nn.Module): Adds a penalty in backprop if feature x and feature y are correlated. Assumes input is (batch, seq, channel) """ - def __init__(self, limit: FloatLike = 0.03): + def __init__(self, limit: FloatLike = 0.0): super().__init__() self.name = None self.limit = limit From b249b01b3cc7d4f43764b4b50d05b6eb2569470c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 24 Dec 2025 23:51:52 +0800 Subject: [PATCH 0791/1191] Set threshold to 0.005. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index aa101fb51b..7708c6f98c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1661,7 +1661,7 @@ class CorrelationLimiter(torch.nn.Module): Adds a penalty in backprop if feature x and feature y are correlated. Assumes input is (batch, seq, channel) """ - def __init__(self, limit: FloatLike = 0.0): + def __init__(self, limit: FloatLike = 0.005): super().__init__() self.name = None self.limit = limit From 56a1d2d2aaf642503b4e33b8d20b2fbbcfb4d7e6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 25 Dec 2025 00:03:42 +0800 Subject: [PATCH 0792/1191] Impose the correlation limit of 0.03 on xx, yy and xy randomly. --- egs/librispeech/ASR/zipformer/scaling.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7708c6f98c..94311f6e45 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,11 +1635,18 @@ def norm(x: Tensor): x = x[:2*half_batch] y = y[:2*half_batch] - x1, x2 = x[0::2], x[1::2] - y1, y2 = y[0::2], y[1::2] - S1 = torch.matmul(x1.reshape(-1, num_channels).t(), y1.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) - S2 = torch.matmul(x2.reshape(-1, num_channels).t(), y2.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) + r = torch.rand(2, device=x.device) < 0.5 + + # a is x or y; b is (independently) x or y. + a = torch.where(r[0], x, y) + b = torch.where(r[1], x, y) + + a1, a2 = a[0::2], a[1::2] + b1, b2 = b[0::2], b[1::2] + + S1 = torch.matmul(a1.reshape(-1, num_channels).t(), b1.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) + S2 = torch.matmul(a2.reshape(-1, num_channels).t(), b2.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) # S1, S2: (num_channels, num_channels) correlation = (S1 * S2).mean() @@ -1661,7 +1668,7 @@ class CorrelationLimiter(torch.nn.Module): Adds a penalty in backprop if feature x and feature y are correlated. Assumes input is (batch, seq, channel) """ - def __init__(self, limit: FloatLike = 0.005): + def __init__(self, limit: FloatLike = 0.03): super().__init__() self.name = None self.limit = limit From 555741d608d73cfb226168f17680af2d85a42f04 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 25 Dec 2025 22:13:11 +0800 Subject: [PATCH 0793/1191] Make correlation be computed on appended x and y. --- egs/librispeech/ASR/zipformer/scaling.py | 22 ++++++++++------------ egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 94311f6e45..bdfcdae488 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1632,21 +1632,19 @@ def norm(x: Tensor): # where they may really be duplicates return None, None, None, None, None - x = x[:2*half_batch] - y = y[:2*half_batch] + x = torch.cat((x, y), dim=-1) + x1, x2 = x[0::2], x[1::2] - r = torch.rand(2, device=x.device) < 0.5 - - # a is x or y; b is (independently) x or y. - a = torch.where(r[0], x, y) - b = torch.where(r[1], x, y) - - a1, a2 = a[0::2], a[1::2] - b1, b2 = b[0::2], b[1::2] + if mask is not None: + numel1 = mask[0::2].sum() + numel2 = mask[1::2].sum() + else: + numel1 = x1.shape[0] * x1.shape[1] + numel2 = x2.shape[0] * x2.shape[1] - S1 = torch.matmul(a1.reshape(-1, num_channels).t(), b1.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) - S2 = torch.matmul(a2.reshape(-1, num_channels).t(), b2.reshape(-1, num_channels)) * (1. / (half_batch * seq_len)) + S1 = torch.matmul(x1.reshape(-1, num_channels).t(), x1.reshape(-1, num_channels)) * (1. / numel1) + S2 = torch.matmul(x2.reshape(-1, num_channels).t(), x2.reshape(-1, num_channels)) * (1. / numel2) # S1, S2: (num_channels, num_channels) correlation = (S1 * S2).mean() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 4a847641c7..6800a6709d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -513,7 +513,8 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) - self.offset_correlation_limiter = CorrelationLimiter() + power = 0.6 + self.offset_correlation_limiter = CorrelationLimiter(limit=(1. / ((2 * embed_dim) ** power))) self.self_attn_weights = MultiheadAttentionWeights( embed_dim, From f8b4c4e8484e09ed90983c43aa07da7dc6974c63 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 25 Dec 2025 22:35:12 +0800 Subject: [PATCH 0794/1191] This should make no diff to results but should be a little faster. --- egs/librispeech/ASR/zipformer/scaling.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index bdfcdae488..7aa7c9b2b1 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,16 +1635,18 @@ def norm(x: Tensor): x = torch.cat((x, y), dim=-1) x1, x2 = x[0::2], x[1::2] + x1 = x1.reshape(-1, num_channels) + x2 = x2.reshape(-1, num_channels) if mask is not None: numel1 = mask[0::2].sum() numel2 = mask[1::2].sum() else: - numel1 = x1.shape[0] * x1.shape[1] - numel2 = x2.shape[0] * x2.shape[1] + numel1 = x1.shape[0] + numel2 = x2.shape[0] - S1 = torch.matmul(x1.reshape(-1, num_channels).t(), x1.reshape(-1, num_channels)) * (1. / numel1) - S2 = torch.matmul(x2.reshape(-1, num_channels).t(), x2.reshape(-1, num_channels)) * (1. / numel2) + S1 = torch.matmul(x1.t(), x1) * (1. / numel1) + S2 = torch.matmul(x2.t(), x2) * (1. / numel2) # S1, S2: (num_channels, num_channels) correlation = (S1 * S2).mean() From f32a9c53d7bc8a7de3352e7d3a8b1eedaba1b818 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 25 Dec 2025 23:38:14 +0800 Subject: [PATCH 0795/1191] Reduce power from .6 to .45 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6800a6709d..92986edc39 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -513,7 +513,7 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) - power = 0.6 + power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.offset_correlation_limiter = CorrelationLimiter(limit=(1. / ((2 * embed_dim) ** power))) self.self_attn_weights = MultiheadAttentionWeights( From 301f866ec8286bbf155160514a6a0ba462dc4067 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 26 Dec 2025 11:45:52 +0800 Subject: [PATCH 0796/1191] fix to docs re --libri-copies option --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index bc77784f80..71585227c4 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -213,9 +213,8 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "--libri-copies", type=int, default=1, - help="If set to <= 0, we use only librispeech (CAUTION: this may be surprising). If set to > 0, every epoch means one epoch " - "of gigaspeech and libri_copies epochs of librispeech (although it is really libri_copies times 3, because of Librispeech " - "using speed augmentation." + help="The number of copies of librispeech data used per epoch, i.e. per epoch of gigaspeech, if --use-giga=True." + "(it is really libri_copies times 3, because of Librispeech using speed augmentation)." ) parser.add_argument( From a31368d680b52ecd013500163d6fa2f1c076bacc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 27 Dec 2025 23:23:29 +0800 Subject: [PATCH 0797/1191] Fix bug regarding num channels, increasing dim; reduce power to 0.4; randomly sub-select from (2*channels) if >512. --- egs/librispeech/ASR/zipformer/scaling.py | 26 ++++++++++++++++++++-- egs/librispeech/ASR/zipformer/zipformer.py | 13 ++++++----- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7aa7c9b2b1..7352c33bf9 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1607,6 +1607,7 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 mask = ctx.mask aux_loss_scale = ctx.aux_loss_scale (batch_size, seq_len, num_channels) = x.shape + with torch.enable_grad(): with torch.amp.autocast('cuda', enabled=False): x, y = x.to(torch.float), y.to(torch.float) @@ -1615,6 +1616,10 @@ def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 y.requires_grad = True x_orig, y_orig = x, y + + y = y - x + # make y be the offset of the new src from the old src; do this in here to save memory and to slightly simplify the interface. + def norm(x: Tensor): eps = 1.0e-20 return x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt() @@ -1634,9 +1639,10 @@ def norm(x: Tensor): x = torch.cat((x, y), dim=-1) + C = x.shape[-1] # 2 * num_channels x1, x2 = x[0::2], x[1::2] - x1 = x1.reshape(-1, num_channels) - x2 = x2.reshape(-1, num_channels) + x1 = x1.reshape(-1, C) + x2 = x2.reshape(-1, C) if mask is not None: numel1 = mask[0::2].sum() @@ -1645,6 +1651,13 @@ def norm(x: Tensor): numel1 = x1.shape[0] numel2 = x2.shape[0] + max_channels = 512 # randomly select a subset of dims if C is more than this, for efficiency + if C > max_channels: + indexes = torch.rand(2, C, device=x.device).sort(dim=1)[1] # indexes: (2, C), type int64 + indexes = indexes[:, :max_channels] # (2, max_channels) + x1 = x1.index_select(dim=2, index=indexes[0]) + x2 = x2.index_select(dim=2, index=indexes[1]) + S1 = torch.matmul(x1.t(), x1) * (1. / numel1) S2 = torch.matmul(x2.t(), x2) * (1. / numel2) @@ -1679,6 +1692,15 @@ def forward(self, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Te # returns a scalar tensor that should be included in the loss function with: # z = with_loss(z, ret, None) # where z is any quantity that will be used in calculating the main loss. + # expected to be called as something like: + # src_orig = src + # src = src + f(src) + # src = src + g(src) + # .. + # src = with_loss(src, self.correlation_limiter(src_orig.permute(1, 0, 2), src.permute(1, 0, 2), + # aux_loss_scale, src_key_padding_mask), + # None) + # (assuming a (seq, batch, channel) layout; this class expects (batch, seq, channel). if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return torch.tensor(0.0, device=x.device) else: diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b6e04f46a9..c048d624c5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -513,7 +513,7 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) - power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + power = 0.4 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.offset_correlation_limiter = CorrelationLimiter(limit=(1. / ((2 * embed_dim) ** power))) self.self_attn_weights = MultiheadAttentionWeights( @@ -583,13 +583,14 @@ def forward( offset = self.offset_scale_limiter(offset, aux_loss_scale) - offset = with_loss(offset, - self.offset_correlation_limiter( - src_orig.permute(1, 0, 2), offset.permute(1, 0, 2), - 2. * aux_loss_scale, mask=src_key_padding_mask)) - src = src_orig + offset + src = with_loss(src, + self.offset_correlation_limiter( + src_orig.permute(1, 0, 2), src.permute(1, 0, 2), + 2. * aux_loss_scale, mask=src_key_padding_mask)) + + src = self.norm(src) return src From 7f82d1fb7d6b7e2dea15d4076b61b68f5a1ab0c0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 27 Dec 2025 23:47:18 +0800 Subject: [PATCH 0798/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7352c33bf9..b0ea4de2fb 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1655,11 +1655,16 @@ def norm(x: Tensor): if C > max_channels: indexes = torch.rand(2, C, device=x.device).sort(dim=1)[1] # indexes: (2, C), type int64 indexes = indexes[:, :max_channels] # (2, max_channels) - x1 = x1.index_select(dim=2, index=indexes[0]) - x2 = x2.index_select(dim=2, index=indexes[1]) + x1a = x1.index_select(dim=1, index=indexes[0]) + x1b = x1.index_select(dim=1, index=indexes[1]) + x2a = x2.index_select(dim=1, index=indexes[0]) + x2b = x2.index_select(dim=1, index=indexes[1]) + else: + x1a, x1b = x1, x1 + x2a, x2b = x2, x2 - S1 = torch.matmul(x1.t(), x1) * (1. / numel1) - S2 = torch.matmul(x2.t(), x2) * (1. / numel2) + S1 = torch.matmul(x1a.t(), x1b) * (1. / numel1) + S2 = torch.matmul(x2a.t(), x2b) * (1. / numel2) # S1, S2: (num_channels, num_channels) correlation = (S1 * S2).mean() From a43a22bc97423e60d43a188aa717284120beb011 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 28 Dec 2025 00:14:41 +0800 Subject: [PATCH 0799/1191] revert limit power from .4 to .45 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index c048d624c5..043c1df349 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -513,7 +513,7 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) - power = 0.4 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.offset_correlation_limiter = CorrelationLimiter(limit=(1. / ((2 * embed_dim) ** power))) self.self_attn_weights = MultiheadAttentionWeights( From 2952801cb92555c7da9a954279cf397462c86230 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 28 Dec 2025 10:19:20 +0800 Subject: [PATCH 0800/1191] Simplify CorrelationLimiter to be on just x (src_orig), not cat(x,y). --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b0ea4de2fb..81bf54b5c2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1638,8 +1638,8 @@ def norm(x: Tensor): return None, None, None, None, None - x = torch.cat((x, y), dim=-1) - C = x.shape[-1] # 2 * num_channels + #x = torch.cat((x, y), dim=-1) + C = x.shape[-1] # num_channels x1, x2 = x[0::2], x[1::2] x1 = x1.reshape(-1, C) x2 = x2.reshape(-1, C) @@ -1666,7 +1666,7 @@ def norm(x: Tensor): S1 = torch.matmul(x1a.t(), x1b) * (1. / numel1) S2 = torch.matmul(x2a.t(), x2b) * (1. / numel2) - # S1, S2: (num_channels, num_channels) + # S1, S2: (N, N) where N = min(num_channels, max_channels) correlation = (S1 * S2).mean() loss = (correlation - ctx.limit).relu() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 043c1df349..a91b83cf4d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -514,7 +514,7 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) - self.offset_correlation_limiter = CorrelationLimiter(limit=(1. / ((2 * embed_dim) ** power))) + self.offset_correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn_weights = MultiheadAttentionWeights( embed_dim, From bb7c5addd2094acf03208addd7fdc655fab28e93 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 2 Jan 2026 15:37:28 +0800 Subject: [PATCH 0801/1191] Simplify correlation limiter code, should make no real difference (except removed random subsetting of dimensions.) --- egs/librispeech/ASR/zipformer/scaling.py | 60 +++++++--------------- egs/librispeech/ASR/zipformer/zipformer.py | 13 +++-- 2 files changed, 24 insertions(+), 49 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 81bf54b5c2..92fbd0ce21 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1593,8 +1593,8 @@ def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: class CorrelationLimiterFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, y: Tensor, aux_loss_scale: float, limit: float, mask: Optional[Tensor], name: str): - ctx.save_for_backward(x, y) + def forward(ctx, x: Tensor, aux_loss_scale: float, limit: float, mask: Optional[Tensor], name: str): + ctx.save_for_backward(x) ctx.mask = mask ctx.limit = limit ctx.aux_loss_scale = aux_loss_scale @@ -1603,33 +1603,26 @@ def forward(ctx, x: Tensor, y: Tensor, aux_loss_scale: float, limit: float, mask @staticmethod def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 - x, y = ctx.saved_tensors + x, = ctx.saved_tensors mask = ctx.mask aux_loss_scale = ctx.aux_loss_scale (batch_size, seq_len, num_channels) = x.shape with torch.enable_grad(): with torch.amp.autocast('cuda', enabled=False): - x, y = x.to(torch.float), y.to(torch.float) - x, y = x.detach(), y.detach() + x = x.to(torch.float) + x = x.detach() x.requires_grad = True - y.requires_grad = True - x_orig, y_orig = x, y - - - y = y - x - # make y be the offset of the new src from the old src; do this in here to save memory and to slightly simplify the interface. + x_orig = x def norm(x: Tensor): eps = 1.0e-20 return x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt() x = norm(x) - y = norm(y) if mask is not None: mask = (~mask).to(x.dtype).unsqueeze(-1) x = x * mask - y = y * mask half_batch = batch_size // 2 if half_batch <= 1: @@ -1651,20 +1644,8 @@ def norm(x: Tensor): numel1 = x1.shape[0] numel2 = x2.shape[0] - max_channels = 512 # randomly select a subset of dims if C is more than this, for efficiency - if C > max_channels: - indexes = torch.rand(2, C, device=x.device).sort(dim=1)[1] # indexes: (2, C), type int64 - indexes = indexes[:, :max_channels] # (2, max_channels) - x1a = x1.index_select(dim=1, index=indexes[0]) - x1b = x1.index_select(dim=1, index=indexes[1]) - x2a = x2.index_select(dim=1, index=indexes[0]) - x2b = x2.index_select(dim=1, index=indexes[1]) - else: - x1a, x1b = x1, x1 - x2a, x2b = x2, x2 - - S1 = torch.matmul(x1a.t(), x1b) * (1. / numel1) - S2 = torch.matmul(x2a.t(), x2b) * (1. / numel2) + S1 = torch.matmul(x1.t(), x1) * (1. / numel1) + S2 = torch.matmul(x2.t(), x2) * (1. / numel2) # S1, S2: (N, N) where N = min(num_channels, max_channels) correlation = (S1 * S2).mean() @@ -1678,13 +1659,17 @@ def norm(x: Tensor): loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device)) - return x_orig.grad, y_orig.grad, None, None, None, None + return x_orig.grad, None, None, None, None class CorrelationLimiter(torch.nn.Module): """ - Adds a penalty in backprop if feature x and feature y are correlated. - Assumes input is (batch, seq, channel) + Adds a penalty in backprop if the input feature has a covariance matrix that is + too different from the identity matrix. limit=1/num_channels is the + smallest limit you can provide but the limit should be much larger than + this, like 1/sqrt(num_channels). + + Assumes input is (batch, seq, channel) """ def __init__(self, limit: FloatLike = 0.03): super().__init__() @@ -1692,24 +1677,15 @@ def __init__(self, limit: FloatLike = 0.03): self.limit = limit - def forward(self, x: Tensor, y: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: - # x and y should both be: (batch, seq, channel) + def forward(self, x: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: + # x should be: (batch, seq, channel) # returns a scalar tensor that should be included in the loss function with: # z = with_loss(z, ret, None) # where z is any quantity that will be used in calculating the main loss. - # expected to be called as something like: - # src_orig = src - # src = src + f(src) - # src = src + g(src) - # .. - # src = with_loss(src, self.correlation_limiter(src_orig.permute(1, 0, 2), src.permute(1, 0, 2), - # aux_loss_scale, src_key_padding_mask), - # None) - # (assuming a (seq, batch, channel) layout; this class expects (batch, seq, channel). if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return torch.tensor(0.0, device=x.device) else: - return CorrelationLimiterFunction.apply(x, y, + return CorrelationLimiterFunction.apply(x, aux_loss_scale, float(self.limit), mask, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a91b83cf4d..7e4254b616 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -514,7 +514,7 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) - self.offset_correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) + self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn_weights = MultiheadAttentionWeights( embed_dim, @@ -562,6 +562,11 @@ def forward( """ src_orig = src + src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), + 2. * aux_loss_scale, mask=src_key_padding_mask), + None) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) attn_weights = self.self_attn_weights( src, @@ -585,12 +590,6 @@ def forward( src = src_orig + offset - src = with_loss(src, - self.offset_correlation_limiter( - src_orig.permute(1, 0, 2), src.permute(1, 0, 2), - 2. * aux_loss_scale, mask=src_key_padding_mask)) - - src = self.norm(src) return src From b4a8017af259ffa059b77d0b35cf4e630e8158f6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 4 Jan 2026 15:47:20 +0800 Subject: [PATCH 0802/1191] Change initialization, and add penalty before sigmoid, of GatedSelfAttention --- egs/librispeech/ASR/zipformer/zipformer.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7e4254b616..6522a41d95 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1044,9 +1044,6 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. - self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) - self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) - key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim) * num_heads @@ -1319,13 +1316,15 @@ def __init__( value_head_dim: int, ) -> None: super().__init__() - self.in_proj = ScaledLinear(embed_dim, 2 *num_heads * value_head_dim, - bias=True) + self.in_proj = ScaledLinear(embed_dim, 2 * num_heads * value_head_dim, + initial_scale=0.1, bias=True) + + self.copy_x = nn.Identity() # diagnostics. self.sigmoid = nn.Sigmoid() self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.5 ) f = max(1.0, embed_dim / (num_heads * value_head_dim)) @@ -1357,6 +1356,7 @@ def forward( x = self.in_proj(x) # (seq_len, batch_size, 2 * num_heads * value_head_dim) x, s = x.chunk(2, dim=-1) x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + x = self.copy_x(x) # now x: (num_heads, batch_size, seq_len, value_head_dim) value_head_dim = x.shape[-1] @@ -1370,8 +1370,13 @@ def forward( .view(seq_len, batch_size, num_heads * value_head_dim) ) + + if self.training: + # don't let the sigmoid values get too extreme, limit to -2..2. + s = penalize_abs_values_gt(s, 2, penalty=0.02*aux_loss_scale) + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = x * self.sigmoid(s) + x = self.sigmoid(s) x = self.out_proj(x) return x From 888ba10f8ebada24a56c9c823bba5bd38b1a2cd4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 4 Jan 2026 15:57:50 +0800 Subject: [PATCH 0803/1191] Bug fix in self-attn module --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6522a41d95..9add9da625 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1376,7 +1376,7 @@ def forward( s = penalize_abs_values_gt(s, 2, penalty=0.02*aux_loss_scale) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.sigmoid(s) + x = x * self.sigmoid(s) x = self.out_proj(x) return x From c05321291ffd872ccf382db9b0d19a48948da5f0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 6 Jan 2026 15:00:27 +0800 Subject: [PATCH 0804/1191] Do nonlinear interpolation in FftConv, with low_freq_factor=0.25. --- egs/librispeech/ASR/zipformer/zipformer.py | 57 ++++++++++++++++++++-- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 9add9da625..6eba70c8dc 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1477,17 +1477,67 @@ def round_up_to_power_of_two(x): +def interpolate_warped(x: Tensor, + freqs_out: int, + low_freq_factor: float, + dim: int): + + """ + Interpolates between the elements of x, similar to x.index_select(dim, ...), but with interpolation. + Args: + x: arbitrary shaped Tensor except that its dimension "dim" will be interpreted as representing + warped frequencies, with the lowest index correponding to frequency 0 and the highest index + corresponding to the nyquist frequency pi, but the frequencies near 0 closer together according + to low_freq_factor. + freqs_out: an integer giving the number of frequencies which we want to interpolate x, with the + 0 and freqs_out-1 representing 0 and respectively. + low_freq_factor: a float 0 < low_freq_factor < 1, e.g. if it is 0.1 then low-numbered frequency + indexes in x will be about 10 times closer together. + + + Returns: + a Tensor with the same shape as x, except dimension "dim" will be of size equal to freqs_out. + Its elements will be interpolated between elements of x. + """ + num_freqs_in = x.shape[dim] + + # note: the factor of math.pi should in principle appear in both freqs_in + # and freqs_out but we omit it from both; this will have no effect on the + # result. + + log_freqs_in = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs_in, device=x.device) + freqs_in = log_freqs_in.exp() - low_freq_factor # these range from 0 to 1. + freqs_out = torch.linspace(0.0, 1.0, freqs_out, device=x.device) # the frequencies of the discrete fourier basis. + + indexes = torch.searchsorted(freqs_in, freqs_out) + indexes = indexes.clamp(min=1, max=num_freqs_in - 1) + indexes1 = indexes - 1 + lower_freq = freqs_in[indexes1] + upper_freq = freqs_in[indexes] + upper_weight = (freqs_out - lower_freq) / (upper_freq - lower_freq) + lower_weight = 1. - upper_weight + + if dim < 0: + dim += x.ndim + for _ in range(dim, x.ndim - 1): + lower_weight = lower_weight.unsqueeze(-1) + upper_weight = upper_weight.unsqueeze(-1) + return lower_weight * x.index_select(dim, indexes1) + upper_weight * x.index_select(dim, indexes) + + + class FftConv(nn.Module): def __init__(self, num_channels: int, params_per_channel: int, + low_freq_factor: float = 0.25, # factor of how far apart specified freqs are on the low end vs the high end bias: bool = True): super().__init__() self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) # one factor of 2 is for (sin, cos); the other is to double the num representable freqs self.weight_proj = nn.Linear(params_per_channel, 4 * params_per_channel) - + self.low_freq_factor = low_freq_factor if bias: self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) @@ -1502,8 +1552,9 @@ def forward(self, x = torch.fft.rfft(x.to(torch.float32), dim=0) # x: (num_freqs, batch_size, num_channels) N = x.shape[0] # num freqs - weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) - weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) + weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) # (num_channels, 2, 2 * params_per_channel) + weight = interpolate_warped(weight, N, self.low_freq_factor, dim=2) + #weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) weight = torch.view_as_complex(weight.permute(2, 0, 1).contiguous()) # weight: (N, num_channels) weight = weight.unsqueeze(1) # (N, 1, num_channels) From f8b356bc1df9fd8c7793ef9b699a0bc675d8ac67 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 6 Jan 2026 17:19:46 +0800 Subject: [PATCH 0805/1191] Add pos_scores to self attention. --- egs/librispeech/ASR/zipformer/zipformer.py | 103 ++++++++++++++++++++- 1 file changed, 100 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6eba70c8dc..3a0779d290 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1035,6 +1035,7 @@ def __init__( embed_dim: int, num_heads: int, query_head_dim: int, + pos_dim: int = 4, dropout: float = 0.0, ) -> None: super().__init__() @@ -1045,7 +1046,7 @@ def __init__( self.name = None # will be overwritten in training code; for diagnostics. key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim) * num_heads + in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5 that has been used in previous forms of attention, @@ -1059,7 +1060,10 @@ def __init__( self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 + self.rel_pos = RelPosScores(num_heads, pos_dim, num_freqs=64, low_freq_factor=0.2) + self.copy_query = Identity() + self.copy_pos_query = Identity() def forward( self, @@ -1091,11 +1095,14 @@ def forward( # self-attention q = x[..., 0:query_dim] k = x[..., query_dim : 2 * query_dim] + p = x[..., 2 * query_dim:] q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, -1) q = self.rope(q.permute(1, 0, 2, 3)) # (batch, seq, head, channel) k = self.rope(k.permute(1, 0, 2, 3)) # (batch, seq, head, channel) @@ -1104,7 +1111,11 @@ def forward( q = q.permute(2, 0, 1, 3) # (head, batch, time1, query_head_dim) k = k.permute(2, 0, 3, 1) # (head, batch, d_k, time2) - attn_scores = torch.matmul(q, k) + attn_scores = torch.matmul(q, k) # (head, batch, time1, time2) + + p = p.permute(1, 2, 0, 3) + pos_scores = self.rel_pos(p) # (batch, head, time1, time2) + attn_scores = attn_scores + pos_scores.permute(1, 0, 2, 3) assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) @@ -1481,7 +1492,6 @@ def interpolate_warped(x: Tensor, freqs_out: int, low_freq_factor: float, dim: int): - """ Interpolates between the elements of x, similar to x.index_select(dim, ...), but with interpolation. Args: @@ -1524,6 +1534,93 @@ def interpolate_warped(x: Tensor, upper_weight = upper_weight.unsqueeze(-1) return lower_weight * x.index_select(dim, indexes1) + upper_weight * x.index_select(dim, indexes) +class RelPosScores(nn.Module): + def __init__(self, + num_heads: int, + pos_dim: int, + num_freqs: int, + low_freq_factor: float): + super().__init__() + self.params = nn.Parameter(0.2 * torch.randn(num_heads, pos_dim * 2, num_freqs)) + self.num_freqs = num_freqs + self.low_freq_factor = low_freq_factor + + def forward(self, p: Tensor) -> Tensor: + """ + Compute and return unnormalized log scores for relative position. + Args: + p: these are the position-queries, of shape (batch_size, num_heads, seq_len, pos_dim) + (they are obtained via projection, just like the queries). + Returns: + scores: (batch_size, num_heads, dest_seq_len, src_seq_len), + + where dest_seq_len and src_seq_len are numerically equal to seq_len but dest_seq_len relates to the + query and src_seq_len to the key. + """ + + (batch_size, num_heads, seq_len, pos_dim) = p.shape + + + # making "factor" more than 1 is to ensure there is plenty of "extra" + # room in this sequence length past seq_len so it's similar to what we'd + # get with infinite sequence length. there will be another factor of 2 + # because S is half the sequence length we use for the FFT + factor = 4 + S = round_up_to_power_of_two(factor * seq_len) + F = S + 1 # the number of frequencies in the FFT, including the nyquist. + + # self.params: (num_heads, pos_dim * 2, num_freqs) + X = interpolate_warped(self.params, F, self.low_freq_factor, dim=2) + + ones = torch.cat([torch.ones(S, device=p.device), torch.zeros(S, device=p.device)]) + + # X: (num_heads, pos_dim * 2, F) + X = torch.view_as_complex(X.reshape(num_heads, pos_dim, 2, F).permute(0, 1, 3, 2).contiguous()) + + Ones = torch.fft.rfft(ones, dim=0) + X = X * Ones + + # X: (num_heads, pos_dim, F); complex. + x = torch.fft.irfft(X, n=2*S, dim=2) + # x: (num_heads, pos_dim * 2, 2 * S) + + x = x.roll(S, dims=2) + # x: (num_heads, pos_dim * 2, 2 * S); now the position of offset=0 is at position S rather than position + # zero. + + x = x[:, :, S - (seq_len - 1) : S + seq_len] + assert x.shape == (num_heads, pos_dim, 2 * seq_len - 1) + + + # with seq_len2 = 2 * seq_len - 1, + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + pos_weights = torch.matmul(p, x) + + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. This is all copied from our old conformer/zipformer code. + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (batch_size, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + return pos_weights + From 662efa82dde17151a2ac510aa6766bc1c255d63e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 6 Jan 2026 22:51:55 +0800 Subject: [PATCH 0806/1191] A different implementation of pos_scores, with a base_dim and sinc^2 envelopes, no FFT involved. --- egs/librispeech/ASR/zipformer/zipformer.py | 63 ++++++++++++---------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3a0779d290..bc5587977a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1060,7 +1060,7 @@ def __init__( self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 - self.rel_pos = RelPosScores(num_heads, pos_dim, num_freqs=64, low_freq_factor=0.2) + self.rel_pos = RelPosScores(num_heads, pos_dim, num_freqs=128, base=10_000) self.copy_query = Identity() self.copy_pos_query = Identity() @@ -1539,11 +1539,19 @@ def __init__(self, num_heads: int, pos_dim: int, num_freqs: int, - low_freq_factor: float): + base: float = 10_000): + """ + Implementation of relative position scores with mathematically sensible sinc envelope. + """ super().__init__() - self.params = nn.Parameter(0.2 * torch.randn(num_heads, pos_dim * 2, num_freqs)) + self.base = base + self.weight = nn.Parameter(0.2 * torch.randn(num_heads, pos_dim, 2 * num_freqs)) + #n = (num_freqs//8) + #with torch.no_grad(): + # self.weight[..., :n] = 0. + # self.weight[..., n+1:] = 0. + self.num_freqs = num_freqs - self.low_freq_factor = low_freq_factor def forward(self, p: Tensor) -> Tensor: """ @@ -1560,37 +1568,39 @@ def forward(self, p: Tensor) -> Tensor: (batch_size, num_heads, seq_len, pos_dim) = p.shape + num_freqs = self.num_freqs - # making "factor" more than 1 is to ensure there is plenty of "extra" - # room in this sequence length past seq_len so it's similar to what we'd - # get with infinite sequence length. there will be another factor of 2 - # because S is half the sequence length we use for the FFT - factor = 4 - S = round_up_to_power_of_two(factor * seq_len) - F = S + 1 # the number of frequencies in the FFT, including the nyquist. + freqs = math.pi * torch.linspace(0., -math.log(self.base), num_freqs + 1, device=p.device).exp()[1:] # base freqs. - # self.params: (num_heads, pos_dim * 2, num_freqs) - X = interpolate_warped(self.params, F, self.low_freq_factor, dim=2) + triangle_size = (self.base ** (1. / (num_freqs + 1))) - 1 + # e.g. 0.15. triangle_size is the relative separation between frequencies. - ones = torch.cat([torch.ones(S, device=p.device), torch.zeros(S, device=p.device)]) + # e.g. if base ** (1/(num_freqs+1)) == 1.15, then triangle_size == 0.15. + # see + # https://www.physicsforums.com/threads/fourier-transform-triangular-pulse.850993/, + # it corresponds to b. we want sinc(\omega b / 2) where b == + # freqs[i] * triangle_size, with the version of sinc without 2pi; and omega here was + # frequency, but actually it's really time as the triangle was the envelope in + # fourier space and our FFT is actually the inverse FFT. - # X: (num_heads, pos_dim * 2, F) - X = torch.view_as_complex(X.reshape(num_heads, pos_dim, 2, F).permute(0, 1, 3, 2).contiguous()) - Ones = torch.fft.rfft(ones, dim=0) - X = X * Ones + t = torch.arange(-(seq_len - 1), seq_len, device=p.device) - # X: (num_heads, pos_dim, F); complex. - x = torch.fft.irfft(X, n=2*S, dim=2) - # x: (num_heads, pos_dim * 2, 2 * S) + print("freqs = ", freqs) - x = x.roll(S, dims=2) - # x: (num_heads, pos_dim * 2, 2 * S); now the position of offset=0 is at position S rather than position - # zero. + angles = t.unsqueeze(-1) * freqs # (2*seq_len - 1, num_freqs) - x = x[:, :, S - (seq_len - 1) : S + seq_len] - assert x.shape == (num_heads, pos_dim, 2 * seq_len - 1) + def sinc2(x): + return torch.where(x == 0.0, torch.ones_like(x), x.sin() / x) ** 2 + + envelope = sinc2(angles * (0.5 * triangle_size)) + cos = angles.cos() * envelope + sin = angles.sin() * envelope + basis = torch.cat((cos, sin), dim=1) # (2 * seq_len - 1, 2 * num_freqs) + + x = torch.matmul(self.weight, basis.t()) + assert x.shape == (num_heads, pos_dim, 2 * seq_len - 1) # with seq_len2 = 2 * seq_len - 1, # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) @@ -1623,7 +1633,6 @@ def forward(self, p: Tensor) -> Tensor: - class FftConv(nn.Module): def __init__(self, num_channels: int, From f42b7e4625e89b8ecc81a153929f40da4dce89d4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 6 Jan 2026 23:41:57 +0800 Subject: [PATCH 0807/1191] Revert 1889 so baseline is 1888. --- egs/librispeech/ASR/zipformer/zipformer.py | 55 +--------------------- 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bc5587977a..a3fee9d8e7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1486,54 +1486,6 @@ def round_up_to_power_of_two(x): x = x + 1 return x - - -def interpolate_warped(x: Tensor, - freqs_out: int, - low_freq_factor: float, - dim: int): - """ - Interpolates between the elements of x, similar to x.index_select(dim, ...), but with interpolation. - Args: - x: arbitrary shaped Tensor except that its dimension "dim" will be interpreted as representing - warped frequencies, with the lowest index correponding to frequency 0 and the highest index - corresponding to the nyquist frequency pi, but the frequencies near 0 closer together according - to low_freq_factor. - freqs_out: an integer giving the number of frequencies which we want to interpolate x, with the - 0 and freqs_out-1 representing 0 and respectively. - low_freq_factor: a float 0 < low_freq_factor < 1, e.g. if it is 0.1 then low-numbered frequency - indexes in x will be about 10 times closer together. - - - Returns: - a Tensor with the same shape as x, except dimension "dim" will be of size equal to freqs_out. - Its elements will be interpolated between elements of x. - """ - num_freqs_in = x.shape[dim] - - # note: the factor of math.pi should in principle appear in both freqs_in - # and freqs_out but we omit it from both; this will have no effect on the - # result. - - log_freqs_in = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs_in, device=x.device) - freqs_in = log_freqs_in.exp() - low_freq_factor # these range from 0 to 1. - freqs_out = torch.linspace(0.0, 1.0, freqs_out, device=x.device) # the frequencies of the discrete fourier basis. - - indexes = torch.searchsorted(freqs_in, freqs_out) - indexes = indexes.clamp(min=1, max=num_freqs_in - 1) - indexes1 = indexes - 1 - lower_freq = freqs_in[indexes1] - upper_freq = freqs_in[indexes] - upper_weight = (freqs_out - lower_freq) / (upper_freq - lower_freq) - lower_weight = 1. - upper_weight - - if dim < 0: - dim += x.ndim - for _ in range(dim, x.ndim - 1): - lower_weight = lower_weight.unsqueeze(-1) - upper_weight = upper_weight.unsqueeze(-1) - return lower_weight * x.index_select(dim, indexes1) + upper_weight * x.index_select(dim, indexes) - class RelPosScores(nn.Module): def __init__(self, num_heads: int, @@ -1586,8 +1538,6 @@ def forward(self, p: Tensor) -> Tensor: t = torch.arange(-(seq_len - 1), seq_len, device=p.device) - print("freqs = ", freqs) - angles = t.unsqueeze(-1) * freqs # (2*seq_len - 1, num_freqs) def sinc2(x): @@ -1637,13 +1587,11 @@ class FftConv(nn.Module): def __init__(self, num_channels: int, params_per_channel: int, - low_freq_factor: float = 0.25, # factor of how far apart specified freqs are on the low end vs the high end bias: bool = True): super().__init__() self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) # one factor of 2 is for (sin, cos); the other is to double the num representable freqs self.weight_proj = nn.Linear(params_per_channel, 4 * params_per_channel) - self.low_freq_factor = low_freq_factor if bias: self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) @@ -1659,8 +1607,7 @@ def forward(self, # x: (num_freqs, batch_size, num_channels) N = x.shape[0] # num freqs weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) # (num_channels, 2, 2 * params_per_channel) - weight = interpolate_warped(weight, N, self.low_freq_factor, dim=2) - #weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) + weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) weight = torch.view_as_complex(weight.permute(2, 0, 1).contiguous()) # weight: (N, num_channels) weight = weight.unsqueeze(1) # (N, 1, num_channels) From 52b98b003b7d63dc79d21271fb9c55d40ade4861 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 7 Jan 2026 22:48:28 +0800 Subject: [PATCH 0808/1191] Initialize weights of RelPosScores in a low pass way. --- egs/librispeech/ASR/zipformer/zipformer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a3fee9d8e7..ddc7bb9783 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1498,10 +1498,11 @@ def __init__(self, super().__init__() self.base = base self.weight = nn.Parameter(0.2 * torch.randn(num_heads, pos_dim, 2 * num_freqs)) - #n = (num_freqs//8) - #with torch.no_grad(): - # self.weight[..., :n] = 0. - # self.weight[..., n+1:] = 0. + + with torch.no_grad(): + # initialize the weight in a low-pass way. + for _ in range(10): + self.weight[:] = (2 ** -0.5) * (self.weight + self.weight.roll(1, dims=2)) self.num_freqs = num_freqs From d6c4e8af0e2640578e2af71af5233f428ef7d75f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 8 Jan 2026 15:41:51 +0800 Subject: [PATCH 0809/1191] Remove rope --- egs/librispeech/ASR/zipformer/zipformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ddc7bb9783..b636938e04 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1058,7 +1058,7 @@ def __init__( bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) - self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 + #self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 self.rel_pos = RelPosScores(num_heads, pos_dim, num_freqs=128, base=10_000) @@ -1104,12 +1104,12 @@ def forward( k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) p = p.reshape(seq_len, batch_size, num_heads, -1) - q = self.rope(q.permute(1, 0, 2, 3)) # (batch, seq, head, channel) - k = self.rope(k.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + #q = self.rope(q.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + #k = self.rope(k.permute(1, 0, 2, 3)) # (batch, seq, head, channel) # time1 refers to target, time2 refers to source. - q = q.permute(2, 0, 1, 3) # (head, batch, time1, query_head_dim) - k = k.permute(2, 0, 3, 1) # (head, batch, d_k, time2) + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) attn_scores = torch.matmul(q, k) # (head, batch, time1, time2) From 07c432278cd5f84b34dec2e31ddede79edd6017d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 8 Jan 2026 23:36:56 +0800 Subject: [PATCH 0810/1191] Change RelPosScores to use asymmetric window, use 64 not 128 frequencies, have the freqs evenly spaced around 0. --- egs/librispeech/ASR/zipformer/zipformer.py | 145 ++++++++++++++++----- 1 file changed, 115 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b636938e04..cf1b571f9d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1060,7 +1060,7 @@ def __init__( #self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 - self.rel_pos = RelPosScores(num_heads, pos_dim, num_freqs=128, base=10_000) + self.rel_pos = RelPosScores(num_heads, pos_dim, num_freqs=64) self.copy_query = Identity() self.copy_pos_query = Identity() @@ -1486,25 +1486,128 @@ def round_up_to_power_of_two(x): x = x + 1 return x + +# wolfram alpha: +# the right part of the triangular bin, from 0 to +W. +# definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega +# = -(i W x + e^(-i W x) - 1)/(W x^2) +# Re[definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega] +# = (1 - cos(W x))/(W x^2) +# Im[definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega] +# = (sin(W x) - W x)/(W x^2) + +# the left part of the triangular bin, from -W to 0. +# definite integral from omega = -W to 0 of (omega/W + 1) exp(-i x \omega) d\omega +# (i W x - e^(i W x) + 1)/(W x^2) +# +# Let the center frequency be C. +# right side: +# = e^(i C x) * -(i W x + e^(-i W x) - 1)/(W x^2) +# "alternate form including W, C and x are real": [note, this is left hand width, W_l] +# (W x sin(C x) - cos(x (C - W)) + cos(C x))/(W x^2) - (i (sin(x (C - W)) + W x cos(C x) - sin(C x)))/(W x^2) +# +# left side: +# e^(i C x) * (i W x - e^(i W x) + 1)/(W x^2) +# "alternate form including W, C and x are real": [note, this is right hand width, W_r] +# -(W x sin(C x) + cos(x (C + W)) - cos(C x))/(W x^2) + (i (-sin(x (C + W)) + W x cos(C x) + sin(C x)))/(W x^2) +# +# summing the left and right sides: +# Real part: +# +# (W_r x sin(C x) - cos(x (C - W_r)) + cos(C x))/(W_r x^2) +# -(W_l x sin(C x) + cos(x (C + W_l)) - cos(C x))/(W_l x^2) +# = (cos(C x) - cos((C - W_r)x)) / W_r x^2 +# + (cos(C x) - cos((C + W_l)x)) / W_l x^2 + +# Imaginary part: +# -(sin(x (C - W_r)) + W_r x cos(C x) - sin(C x))) / (W_r x^2) +# +(-sin(x (C + W_l)) + W_l x cos(C x) + sin(C x)) / (W_l x^2) +# = ( sin(C x) - sin((C - W_r)x) ) / (W_r x^2) +# + ( sin(C x) - sin((C + W_l)x) ) / (W_l x^2) + +def compute_angular_freq_basis_triangular(freqs: Tensor, + t: Tensor, + scale: bool) -> Tensor: + """ + This function computes a set of windowed sinusoidal functions + corresponding to the real and imaginary parts of possibly-asymmetrical + triangular angular-frequency bins in frequency space. This basis + allows you to approximate functions whose fourier spectrum is + a piecewise linear function of frequency, with the x-axis values of + the inflection points of the piecewise linear function corresponding + to the supplied "freqs". + + Args: + freqs: the frequencies of the triangular-bin centers; the left and + right parts of the widths of the triangular bins correspond to the + distances to the two adjacent bins; for the "edge" bins, the + "edge" distances are duplicated. + t: the "t" (or x) values for which we want to evaluate the basis; this + will normally be some kind of arange expression e.g. arange(100). + scale: if True, the returned basis will contain the "natural" scaling + factors that arise from the bin widths; if False, it will be + normalized so that the maximum absolute value of the real + functions (attained at t==0) is 1. + + + Returns: + Returns the real and imaginary parts of the basis functions, with + shape (t.size(), freqs.size(), 2) + """ + dtype = freqs.dtype + freqs = freqs.to(torch.double) + t = t.to(torch.double) + + t = t.unsqueeze(-1) + + + C = freqs # Center frequencies of bins. + W = freqs[1:] - freqs[:-1] # the differences between the frequencies + W_l = torch.cat((W[:1], W)) # the difference between each center freq and the freq to the left + W_r = torch.cat((W, W[-1:])) # the difference between each center freq and the freq to the right + + angles = C * t + angles_r = (C - W_r) * t + angles_l = (C + W_l) * t + t2 = t**2 + scale_factor = 0.5 * (W_r + W_l) + + re = torch.where(t == 0., scale_factor, + (angles.cos() - angles_r.cos()) / (W_r * t2) + (angles.cos() - angles_l.cos()) / (W_l * t2)) + im = torch.where(t == 0., 0.0, + (angles.sin() - angles_r.sin()) / (W_r * t2) + (angles.sin() - angles_l.sin()) / (W_l * t2)) + + + if not scale: + re = re / scale_factor + im = im / scale_factor + + return torch.stack((re, im), dim=-1).to(dtype) + + + class RelPosScores(nn.Module): def __init__(self, num_heads: int, pos_dim: int, num_freqs: int, - base: float = 10_000): + low_freq_factor: float = 0.001): """ Implementation of relative position scores with mathematically sensible sinc envelope. """ super().__init__() - self.base = base self.weight = nn.Parameter(0.2 * torch.randn(num_heads, pos_dim, 2 * num_freqs)) - with torch.no_grad(): # initialize the weight in a low-pass way. for _ in range(10): self.weight[:] = (2 ** -0.5) * (self.weight + self.weight.roll(1, dims=2)) - self.num_freqs = num_freqs + + log_freqs = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs) + freqs = math.pi * (log_freqs.exp() - low_freq_factor) # these range from 0 to pi. + freqs[0] = 0.0 # in case of roundoff (it should be 0, mathematically) + self.register_buffer('freqs', freqs, persistent=False) + def forward(self, p: Tensor) -> Tensor: """ @@ -1521,34 +1624,16 @@ def forward(self, p: Tensor) -> Tensor: (batch_size, num_heads, seq_len, pos_dim) = p.shape - num_freqs = self.num_freqs - - freqs = math.pi * torch.linspace(0., -math.log(self.base), num_freqs + 1, device=p.device).exp()[1:] # base freqs. - - triangle_size = (self.base ** (1. / (num_freqs + 1))) - 1 - # e.g. 0.15. triangle_size is the relative separation between frequencies. - - # e.g. if base ** (1/(num_freqs+1)) == 1.15, then triangle_size == 0.15. - # see - # https://www.physicsforums.com/threads/fourier-transform-triangular-pulse.850993/, - # it corresponds to b. we want sinc(\omega b / 2) where b == - # freqs[i] * triangle_size, with the version of sinc without 2pi; and omega here was - # frequency, but actually it's really time as the triangle was the envelope in - # fourier space and our FFT is actually the inverse FFT. + freqs = self.freqs # base freqs t = torch.arange(-(seq_len - 1), seq_len, device=p.device) - - angles = t.unsqueeze(-1) * freqs # (2*seq_len - 1, num_freqs) - - def sinc2(x): - return torch.where(x == 0.0, torch.ones_like(x), x.sin() / x) ** 2 - - envelope = sinc2(angles * (0.5 * triangle_size)) - cos = angles.cos() * envelope - sin = angles.sin() * envelope - - basis = torch.cat((cos, sin), dim=1) # (2 * seq_len - 1, 2 * num_freqs) + basis = compute_angular_freq_basis_triangular(freqs, t, scale=False) + # basis: (2 * seq_len - 1, num_freqs, 2) + basis = basis.permute(0, 2, 1) + # permute it because of how we did the low-pass initialization of weight, we want + # the cos and sin parts to each be continuous ranges, not interleaved. + basis = basis.reshape(basis.shape[0], -1) # (2 * seq_len - 1, 2 * num_freqs) x = torch.matmul(self.weight, basis.t()) assert x.shape == (num_heads, pos_dim, 2 * seq_len - 1) From f5b7eae7a0ed139e3640990070b6a10de38264c0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 13:25:52 +0800 Subject: [PATCH 0811/1191] Add testing for muon. --- egs/librispeech/ASR/zipformer/muon.py | 226 +++++++++++++++++++++++++ egs/librispeech/ASR/zipformer/optim.py | 102 +++++++++++ 2 files changed, 328 insertions(+) create mode 100644 egs/librispeech/ASR/zipformer/muon.py diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py new file mode 100644 index 0000000000..d7482c36b8 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -0,0 +1,226 @@ +# Copyright 2025 Moonshot AI and the LlamaFactory team. +# +# This code is based on the MoonshotAI's Moonlight library. +# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py +# and the Keller Jordan's Muon library. +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2025 Moonshot AI +# Copyright (c) 2024 Keller Jordan +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math + +import torch + + +def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor": + """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. + + We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. + For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing + the slope at zero even beyond the point where the iteration no longer converges all the way to + one everywhere on the interval. This iteration therefore does not produce UV^T but rather something + like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + """Muon - MomentUm Orthogonalized by Newton-schulz. + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + lr=1e-3, + wd=0.1, + muon_params=None, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + ): + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float: + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + # Muon loop + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + # Adam backup + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index d4e26e82ac..923503e8ce 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1613,6 +1613,107 @@ def _test_transformed_adam(hidden_dim: int): logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") + +def _test_muon(hidden_dim: int): + import timeit + + from muon import Muon + from scaling import OrthogonalLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_muon") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + if True: + fix_random_seed(42) + Linear = torch.nn.Linear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + OrthogonalLinear(hidden_dim, hidden_dim, bias=True, + in_groups=2, group_size=hidden_dim//4), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + optim = Muon(muon_params=[m for m in m.parameters() if m.ndim == 2], + adamw_params=[m for m in m.parameters() if m.ndim != 2], + lr=1e-03) + + scheduler = Sched3(optim, lr_batches=100, power=0.9, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + # if epoch == 100 and test in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 512 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + scheduler.step_batch() + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + norm2 = '%.2e' % (m[1].weight**2).mean().sqrt().item() + norm3 = '%.2e' % (m[3].weight**2).mean().sqrt().item() + norm4 = '%.2e' % (m[5].weight**2).mean().sqrt().item() + + bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() + bias_norm2 = '%.2e' % (m[3].bias**2).mean().sqrt().item() + bias_norm3 = '%.2e' % (m[5].bias**2).mean().sqrt().item() + + lr = scheduler.get_last_lr()[0] + logging.info( + f"Test muon, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3,norm4}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + ) + loss.log().backward() + optim.step() + optim.zero_grad() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Muon: time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) @@ -1630,6 +1731,7 @@ def _test_transformed_adam(hidden_dim: int): else: hidden_dim = 200 + _test_muon(hidden_dim) _test_transformed_adam(hidden_dim) _test_eden() _test_sched3() From ebb0480a9cef2731c51a7e1031a733b411296aca Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 13:40:05 +0800 Subject: [PATCH 0812/1191] Make train.py use muon --- egs/librispeech/ASR/zapformer/muon.py | 1 + egs/librispeech/ASR/zapformer/train.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) create mode 120000 egs/librispeech/ASR/zapformer/muon.py diff --git a/egs/librispeech/ASR/zapformer/muon.py b/egs/librispeech/ASR/zapformer/muon.py new file mode 120000 index 0000000000..847edc7f4c --- /dev/null +++ b/egs/librispeech/ASR/zapformer/muon.py @@ -0,0 +1 @@ +../zipformer/muon.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 061f95aa81..7b705c07c6 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -75,6 +75,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Sched3, TransformedAdam +from muon import Muon from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -440,7 +441,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." + "--base-lr", type=float, default=0.001, help="The base learning rate." ) parser.add_argument( @@ -1378,11 +1379,10 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = TransformedAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - debug_interval=params.debug_interval, + optimizer = Muon( + muon_params=[ m for m in model.parameters() if m.ndim==2], + adamw_params=[ m for m in model.parameters() if m.ndim!=2], + lr=params.base_lr, ) scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) From 4f49f7346869ec34c89c79529e997bc1c595d2f5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 13:55:11 +0800 Subject: [PATCH 0813/1191] remove print_debug_info --- egs/librispeech/ASR/zapformer/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7b705c07c6..d5f0e321cf 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1122,8 +1122,8 @@ def get_scaler_scale(): return 1.0 def save_bad_model(suffix: str = ""): - if params.debug_interval > 0: - optimizer.write_debug_info(summary_writer=tb_writer) + #if params.debug_interval > 0: + # optimizer.write_debug_info(summary_writer=tb_writer) save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, @@ -1277,8 +1277,8 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) - if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: - optimizer.write_debug_info(summary_writer=tb_writer) + #if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + #optimizer.write_debug_info(summary_writer=tb_writer) loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value From b707d75c0890faeb685f143d1244205bc59f8366 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 14:02:25 +0800 Subject: [PATCH 0814/1191] Add scale for self_attn in_proj in zipformer.py --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index cf1b571f9d..7bfb31c33a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1084,9 +1084,9 @@ def forward( a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). """ - x = self.in_proj(x) query_head_dim = self.query_head_dim num_heads = self.num_heads + x = self.in_proj(x) * (query_head_dim ** -0.25) seq_len, batch_size, _ = x.shape From 30f2e31646b912e545e7d31e8e3b4f9bacf8133c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 14:16:27 +0800 Subject: [PATCH 0815/1191] Increas aux_loss_scale of attn scores penalty by 10. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7bfb31c33a..75482ed96a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1233,7 +1233,7 @@ def streaming_forward( if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. - attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, 0.1 * aux_loss_scale, + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, aux_loss_scale, key_padding_mask, self.name) attn_weights = attn_scores.softmax(dim=-1) From eb26e8b619d4f8ca7309b6e76552382df6700ca7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 14:20:26 +0800 Subject: [PATCH 0816/1191] Decrease warmup_start to 0.1. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 923503e8ce..6961fc4760 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1663,7 +1663,7 @@ def _test_muon(hidden_dim: int): adamw_params=[m for m in m.parameters() if m.ndim != 2], lr=1e-03) - scheduler = Sched3(optim, lr_batches=100, power=0.9, verbose=False) + scheduler = Sched3(optim, lr_batches=100, power=0.9, warmup_start=0.1, verbose=False) start = timeit.default_timer() avg_loss = 0.0 From e1df18c4747c9dbe0db1015d947140e556189d14 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 14:30:40 +0800 Subject: [PATCH 0817/1191] Introduce arbitrary factor of 0.2 into self_attn weights proj. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 75482ed96a..33fa3fddf4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1086,7 +1086,7 @@ def forward( """ query_head_dim = self.query_head_dim num_heads = self.num_heads - x = self.in_proj(x) * (query_head_dim ** -0.25) + x = self.in_proj(x) * (0.2 * (query_head_dim ** -0.25)) seq_len, batch_size, _ = x.shape From bd6f54797712c5e024eae1a339ab4bb34fa3281f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 14:45:49 +0800 Subject: [PATCH 0818/1191] Remove factor of 0.2 in multihead attn weights --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 33fa3fddf4..75482ed96a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1086,7 +1086,7 @@ def forward( """ query_head_dim = self.query_head_dim num_heads = self.num_heads - x = self.in_proj(x) * (0.2 * (query_head_dim ** -0.25)) + x = self.in_proj(x) * (query_head_dim ** -0.25) seq_len, batch_size, _ = x.shape From a4ca8f46fb323c1f6c12b9745d7c2e2d2e9aca01 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 15:53:03 +0800 Subject: [PATCH 0819/1191] Decrease weight-decay tenfold to 0.01. --- egs/librispeech/ASR/zapformer/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d5f0e321cf..7bedd96175 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1383,6 +1383,7 @@ def run(rank, world_size, args): muon_params=[ m for m in model.parameters() if m.ndim==2], adamw_params=[ m for m in model.parameters() if m.ndim!=2], lr=params.base_lr, + wd=0.01, ) scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) From 50203c6fe1cb5ce86cb4ed6ff327db64258da3b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 16:47:31 +0800 Subject: [PATCH 0820/1191] Use 4-norm instead of 2-norm for normalization before newton schulz. --- egs/librispeech/ASR/zipformer/muon.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index d7482c36b8..e492b4dd62 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -41,9 +41,18 @@ # SOFTWARE. import math - import torch +import logging + +def norm4(X): + XX = X @ X.T + import random + if random.random() < 0.0001: + norm2 = X.norm() + norm4 = XX.norm().sqrt() + logging.info(f"shape={X.shape}, norm2={norm2} vs norm4={norm4}") + return XX.norm().sqrt() def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor": """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. @@ -60,8 +69,8 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" X = G.bfloat16() if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) + # Ensure spectral 4-norm is at most 1 + X = X / (norm4(X) + 1e-7) # Perform the NS iterations for _ in range(steps): A = X @ X.T From a0c19d06f3f9d7846928619729f0898f7ad581c4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Jan 2026 23:24:01 +0800 Subject: [PATCH 0821/1191] Increase weight decay back to 0.1. --- egs/librispeech/ASR/zapformer/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7bedd96175..d5f0e321cf 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1383,7 +1383,6 @@ def run(rank, world_size, args): muon_params=[ m for m in model.parameters() if m.ndim==2], adamw_params=[ m for m in model.parameters() if m.ndim!=2], lr=params.base_lr, - wd=0.01, ) scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) From fdbacb9eded908be40ee24b6da58fdf7a2920a0d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 23 Jan 2026 11:23:33 +0800 Subject: [PATCH 0822/1191] Reduce power of Sched3 to 0.33. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7bedd96175..cae9386d04 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1386,7 +1386,7 @@ def run(rank, world_size, args): wd=0.01, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.33) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From f859268c93e8f73e44775f0995764c0d9a1dc4fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 25 Jan 2026 21:46:04 +0800 Subject: [PATCH 0823/1191] Initialize muon iterations by row and column normalization. --- egs/librispeech/ASR/zipformer/muon.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index e492b4dd62..4897a679af 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -69,8 +69,12 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" X = G.bfloat16() if G.size(0) > G.size(1): X = X.T + # now x: (rows, cols) with rows <= cols # Ensure spectral 4-norm is at most 1 - X = X / (norm4(X) + 1e-7) + eps = 1e-7 + X = X / ((X ** 2).sum(dim=1, keepdim=True) + eps**2).sqrt() + X = X / ((X ** 2).sum(dim=0) + eps**2).sqrt() + X = X / (norm4(X) + eps) # Perform the NS iterations for _ in range(steps): A = X @ X.T @@ -115,7 +119,7 @@ def __init__( muon_params=None, momentum=0.95, nesterov=True, - ns_steps=5, + ns_steps=3, adamw_params=None, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, From 91dd323743d952001982c0c779f877c337d22399 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 25 Jan 2026 22:01:38 +0800 Subject: [PATCH 0824/1191] Increase steps back to 5. --- egs/librispeech/ASR/zipformer/muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 4897a679af..a9136cdcb3 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -119,7 +119,7 @@ def __init__( muon_params=None, momentum=0.95, nesterov=True, - ns_steps=3, + ns_steps=5, adamw_params=None, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, From 65a64ad904e13947c0dcfd1dd4ac1b908a53d04c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 24 Jan 2026 17:18:59 +0800 Subject: [PATCH 0825/1191] Add debug printout to muon, about rms. --- egs/librispeech/ASR/zipformer/muon.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index a9136cdcb3..10560081cd 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -43,11 +43,13 @@ import math import torch import logging +import random + + def norm4(X): XX = X @ X.T - import random if random.random() < 0.0001: norm2 = X.norm() norm4 = XX.norm().sqrt() @@ -83,6 +85,10 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" if G.size(0) > G.size(1): X = X.T + + if random.random() < 0.01: + logging.info(f"zeropower_via_newtonschulz5: shape={X.shape}, singular-value-rms={X.norm()/(min(X.shape[0],X.shape[1])**0.5)}") + return X From 06a51171def4c41f8510aee74cd9bcb18aed0dcd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 26 Jan 2026 14:51:35 +0800 Subject: [PATCH 0826/1191] Use muon, with reshaping, for all params that have more than one non-trivial dimension --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- egs/librispeech/ASR/zipformer/muon.py | 27 ++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6516397c1e..f09205f6e6 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1380,8 +1380,8 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = Muon( - muon_params=[ m for m in model.parameters() if m.ndim==2], - adamw_params=[ m for m in model.parameters() if m.ndim!=2], + muon_params=[ m for m in model.parameters() if m.numel() != max(m.shape, default=1) ], + adamw_params=[ m for m in model.parameters() if m.numel() == max(m.shape, default=1) ], lr=params.base_lr, ) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 10560081cd..afb384029c 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -56,6 +56,25 @@ def norm4(X): logging.info(f"shape={X.shape}, norm2={norm2} vs norm4={norm4}") return XX.norm().sqrt() +def get_muon_shape(shape): + shape = list(shape) + def prod(l): + ans = l[0] + for n in l[1:]: + ans = ans * n + return ans + n = len(shape) + diffs = [ ] + for i in range(1, n): + prod1 = prod(shape[:i]) + prod2 = prod(shape[i:]) + diff = abs(prod1 - prod2) + diffs.append(diff) + min_diff = min(diffs) + for i in range(1, n): + if diffs[i-1] == min_diff: + return prod(shape[:i]), prod(shape[i:]) + def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor": """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. @@ -66,6 +85,8 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative to UV^T, where USV^T = G is the SVD. """ + orig_shape = G.shape + G = G.reshape(get_muon_shape(orig_shape)) assert len(G.shape) == 2 a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() @@ -89,7 +110,7 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" if random.random() < 0.01: logging.info(f"zeropower_via_newtonschulz5: shape={X.shape}, singular-value-rms={X.norm()/(min(X.shape[0],X.shape[1])**0.5)}") - return X + return X.reshape(orig_shape) class Muon(torch.optim.Optimizer): @@ -147,7 +168,7 @@ def __init__( # Sort parameters into those for which we will use Muon, and those for which we will not for p in muon_params: # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer - assert p.ndim == 2, p.ndim + assert p.ndim > 1, p.ndim self.state[p]["use_muon"] = True for p in adamw_params: # Do not use Muon for parameters in adamw_params @@ -186,8 +207,6 @@ def step(self, closure=None): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) assert g is not None # calc update From 973662c7eacc02eea41b33a247a822d5b65832f9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 26 Jan 2026 17:16:54 +0800 Subject: [PATCH 0827/1191] Introduce factor of 10 on weight in FftConv. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 75482ed96a..ee44ce6a69 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1675,7 +1675,7 @@ def __init__(self, params_per_channel: int, bias: bool = True): super().__init__() - self.weight = nn.Parameter(torch.randn(num_channels, params_per_channel)) + self.weight = nn.Parameter(0.1 * torch.randn(num_channels, params_per_channel)) # one factor of 2 is for (sin, cos); the other is to double the num representable freqs self.weight_proj = nn.Linear(params_per_channel, 4 * params_per_channel) if bias: @@ -1693,6 +1693,10 @@ def forward(self, # x: (num_freqs, batch_size, num_channels) N = x.shape[0] # num freqs weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) # (num_channels, 2, 2 * params_per_channel) + weight = 10. * weight + # this scale of 10 times is because of interactions with commonly + # used optimizers, it's to help this module learn faster than it + # otherwise would. weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) weight = torch.view_as_complex(weight.permute(2, 0, 1).contiguous()) # weight: (N, num_channels) From 8805ad9403d702d5e5cda8867252acb165bc4c5a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 27 Jan 2026 00:24:32 +0800 Subject: [PATCH 0828/1191] Increase weight decay from .1 to .15 --- egs/librispeech/ASR/zapformer/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index f09205f6e6..fba657e25e 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1383,6 +1383,7 @@ def run(rank, world_size, args): muon_params=[ m for m in model.parameters() if m.numel() != max(m.shape, default=1) ], adamw_params=[ m for m in model.parameters() if m.numel() == max(m.shape, default=1) ], lr=params.base_lr, + wd=0.15, ) scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.33) From 1452f699e1ecc45d4e17a712b85e8d0b455e5cc6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 27 Jan 2026 12:25:06 +0800 Subject: [PATCH 0829/1191] Make muon only normalize columns before newton-schulz --- egs/librispeech/ASR/zipformer/muon.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index afb384029c..b142d45f21 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -95,7 +95,6 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" # now x: (rows, cols) with rows <= cols # Ensure spectral 4-norm is at most 1 eps = 1e-7 - X = X / ((X ** 2).sum(dim=1, keepdim=True) + eps**2).sqrt() X = X / ((X ** 2).sum(dim=0) + eps**2).sqrt() X = X / (norm4(X) + eps) # Perform the NS iterations From edeef7a03e341f4670a0878f84d7555f1ada9570 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 27 Jan 2026 15:40:25 +0800 Subject: [PATCH 0830/1191] Introduce learnable scalar scale on input of each zipformer stack. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ee44ce6a69..d87a21faac 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -732,6 +732,7 @@ def __init__( (1. / num_layers) * torch.ones(num_layers) ], dim=0)) + self.input_scale = nn.Parameter(torch.tensor([1.0])) self.copy_bypass = Identity() @@ -771,10 +772,14 @@ def forward( residual_scale = limit_param_value(self.residual_scales[0], min=-1.0, max=-0.5) + input_scale = limit_param_value(self.input_scale, + min=0.5, max=2.0) src_with_bypass = residual_scale * src + src = input_scale * src for i, mod in enumerate(self.layers): + src = mod( src, chunk_size=chunk_size, From 4469fc3d229455e8b9e6bbce408007a9d960f697 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 27 Jan 2026 16:26:01 +0800 Subject: [PATCH 0831/1191] Introduce ff2_scale=1.5 to tune scale of activations in feed_forward2 module. --- egs/librispeech/ASR/zipformer/zipformer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d87a21faac..cf9bc68219 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -581,7 +581,12 @@ def forward( src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) - src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + # ff2_scale is to keep the inputs to the activations in feed_forward2 + # module at about the same magnitude as in the feed_forward1 module + # without requiring the weights to learn different magnitudes (the + # activation is not scale-invariant, it is not relu). + ff2_scale = 1.5 + src = src + ff2_scale * self.feed_forward2(src * (1. / ff2_scale), aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) residual_scale = limit_param_value(self.residual_scale, min=0.25, max=0.75) offset = (src - src_orig) * residual_scale From 64fbb5336a68a9788f49da41b09f95023e029cef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 27 Jan 2026 19:03:07 +0800 Subject: [PATCH 0832/1191] Move normalization by columns to after orthogonalization --- egs/librispeech/ASR/zipformer/muon.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index b142d45f21..6339349786 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -92,10 +92,8 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" X = G.bfloat16() if G.size(0) > G.size(1): X = X.T - # now x: (rows, cols) with rows <= cols # Ensure spectral 4-norm is at most 1 eps = 1e-7 - X = X / ((X ** 2).sum(dim=0) + eps**2).sqrt() X = X / (norm4(X) + eps) # Perform the NS iterations for _ in range(steps): @@ -103,9 +101,14 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X + # now x: (rows, cols) with rows <= cols + scale = (X.shape[0] / X.shape[1]) ** 0.5 # adjust so overall scale is not changed by next line. + X = X * (scale / ((X ** 2).sum(dim=0) + eps**2).sqrt()) + if G.size(0) > G.size(1): X = X.T + if random.random() < 0.01: logging.info(f"zeropower_via_newtonschulz5: shape={X.shape}, singular-value-rms={X.norm()/(min(X.shape[0],X.shape[1])**0.5)}") From a287968662bf1b679f11540bab96dee5393ed98b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Jan 2026 14:19:46 +0800 Subject: [PATCH 0833/1191] Normalize columns of X both before and after orthogonalization. --- egs/librispeech/ASR/zipformer/muon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 6339349786..c598daae4e 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -94,6 +94,7 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" X = X.T # Ensure spectral 4-norm is at most 1 eps = 1e-7 + X = X / ((X ** 2).sum(dim=0) + eps**2).sqrt() # normalize columns X = X / (norm4(X) + eps) # Perform the NS iterations for _ in range(steps): From 6791ed9581fac5379891c307960705f2b71b69c4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Jan 2026 16:07:53 +0800 Subject: [PATCH 0834/1191] Inroduce lr_scale=0.66 for .proj of zipformer encoder in way that does not require get_parameter_groups_with_lrs() --- egs/librispeech/ASR/zipformer/scaling.py | 21 ++++++++++----------- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 92fbd0ce21..de0cc2c91e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1082,6 +1082,10 @@ class SimpleOrthogonalLinear(nn.Linear): Args: in_channels: number of input channels out_channels: number of output channels + lr_scale: we will scale the weight by this value before applying the orthogonal + constraint and using it; with most optimizers + this will have the effect of slowing down the learning by this factor because + the parameter value will be larger. bias: if True, include a bias term. penalty_scale: a scale on the penalty on non-orthogonality (this will be multiplied by the average-absolute-value of the @@ -1092,25 +1096,17 @@ class SimpleOrthogonalLinear(nn.Linear): def __init__(self, in_channels: int, out_channels: int, - in_groups: int = -1, - out_groups: int = -1, - group_size: int = -1, + lr_scale: float, bias: bool = True, penalty_scale: FloatLike = 20.0, ): super().__init__(in_channels, out_channels, bias=bias) self.name = None - self.in_groups = in_groups - self.out_groups = out_groups - if in_groups > 0 and group_size == -1: - group_size = in_channels // in_groups - elif out_groups > 0 and group_size == -1: - group_size = out_channels // out_groups - self.group_size = group_size self.penalty_scale = copy.deepcopy(penalty_scale) + self.weight_scale = lr_scale with torch.no_grad(): - self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) + self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * (1. / lr_scale) if self.bias is not None: torch.nn.init.uniform_(self.bias, -0.01, 0.01) @@ -1118,6 +1114,9 @@ def __init__(self, def forward(self, x: Tensor, transpose: bool = False): # you can only use transpose=True if you used bias=False in initialization weight = self.weight + weight_scale = self.weight_scale + if weight_scale != 1.0: + weight = weight * weight_scale if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): weight = SimpleOrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index cf9bc68219..eafcdb1b0d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -723,8 +723,8 @@ def __init__( super().__init__() # self.downsample will also reverse the downsampling operation for us afterward. - self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) - self.proj.lr_scale = 0.75 + self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, + lr_scale=0.66, bias=False) self.name = None self.layers = nn.ModuleList( From 7457681b1b839fd713eeea85274c5dccd0245cbd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 29 Jan 2026 12:51:10 +0800 Subject: [PATCH 0835/1191] Move the scale from self_attn_weights.in_proj to query, to no affect rel_pos; and reduce initial scale of RelPosScores. --- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index eafcdb1b0d..e6a566b4e7 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1096,14 +1096,14 @@ def forward( """ query_head_dim = self.query_head_dim num_heads = self.num_heads - x = self.in_proj(x) * (query_head_dim ** -0.25) + x = self.in_proj(x) seq_len, batch_size, _ = x.shape query_dim = query_head_dim * num_heads # self-attention - q = x[..., 0:query_dim] + q = x[..., 0:query_dim] * (query_head_dim ** -0.5) k = x[..., query_dim : 2 * query_dim] p = x[..., 2 * query_dim:] @@ -1606,7 +1606,7 @@ def __init__(self, Implementation of relative position scores with mathematically sensible sinc envelope. """ super().__init__() - self.weight = nn.Parameter(0.2 * torch.randn(num_heads, pos_dim, 2 * num_freqs)) + self.weight = nn.Parameter(0.04 * torch.randn(num_heads, pos_dim, 2 * num_freqs)) with torch.no_grad(): # initialize the weight in a low-pass way. for _ in range(10): From 5c8abe22360de2f33fe9cc745b9be5c442e8d732 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 29 Jan 2026 21:41:14 +0800 Subject: [PATCH 0836/1191] introduce scales on decoder (10) and joiner (2) --- egs/librispeech/ASR/zipformer/decoder.py | 6 +++++- egs/librispeech/ASR/zipformer/joiner.py | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index bf49726b95..df41fe29ab 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -57,6 +57,10 @@ def __init__( num_embeddings=vocab_size, embedding_dim=decoder_dim, ) + with torch.no_grad(): + # and we will scale by 10 in forward. this is because with an optimizer that has weight decay, + # it's best if all the parameters have fairly similar dynamic range. + self.embedding.weight[:] *= 0.1 self.blank_id = blank_id @@ -92,7 +96,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: y = y.to(torch.int64) # this stuff about clamp() is a temporary fix for a mismatch # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) * 10.0 if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index 0406efe834..9638ab6b2a 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -62,6 +62,8 @@ def forward( else: logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) + # the scale of 2.0 is arbitrary, it is intended to modulate the speed at which joiner.output_linear trains, + # make it train faster by reducing its scale. + logit = 2.0 * self.output_linear(torch.tanh(logit)) return logit From 284bf13572a8ff11d494da34da3521164d203d49 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 31 Jan 2026 12:57:55 +0800 Subject: [PATCH 0837/1191] Make residual_scale a constant of 0.25 without removing the parameter to keep initialization the same. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e6a566b4e7..f68bf01b73 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -588,7 +588,8 @@ def forward( ff2_scale = 1.5 src = src + ff2_scale * self.feed_forward2(src * (1. / ff2_scale), aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - residual_scale = limit_param_value(self.residual_scale, min=0.25, max=0.75) + #residual_scale = limit_param_value(self.residual_scale, min=0.25, max=0.75) + residual_scale = 0.25 offset = (src - src_orig) * residual_scale offset = self.offset_scale_limiter(offset, aux_loss_scale) From 2a780d6590262df0db507399fc2bd99d7ae057a3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 31 Jan 2026 15:37:40 +0800 Subject: [PATCH 0838/1191] Remove offset_scale_limiter. --- egs/librispeech/ASR/zipformer/zipformer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f68bf01b73..0a28368bd4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -512,7 +512,6 @@ def __init__( self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) @@ -592,8 +591,6 @@ def forward( residual_scale = 0.25 offset = (src - src_orig) * residual_scale - offset = self.offset_scale_limiter(offset, aux_loss_scale) - src = src_orig + offset src = self.norm(src) From 1c1106c84bbe992bee14ccd7176b3e0886e5a820 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 1 Feb 2026 21:41:22 +0800 Subject: [PATCH 0839/1191] Add learned scaling to muon, limits from .66 to 1.5. --- egs/librispeech/ASR/zipformer/muon.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index c598daae4e..403a438199 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -153,6 +153,7 @@ def __init__( adamw_params=None, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, + scale_limits=(0.66, 1.5), ): defaults = dict( lr=lr, @@ -162,6 +163,7 @@ def __init__( ns_steps=ns_steps, adamw_betas=adamw_betas, adamw_eps=adamw_eps, + scale_limits=scale_limits, ) params = list(muon_params) @@ -203,6 +205,7 @@ def step(self, closure=None): lr = group["lr"] wd = group["wd"] momentum = group["momentum"] + min_scale, max_scale = group["scale_limits"] # generate weight updates in distributed fashion for p in params: @@ -216,22 +219,36 @@ def step(self, closure=None): state = self.state[p] if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) + state["scale"] = torch.tensor(1.0, device=g.device) # scalar buf = state["momentum_buffer"] + scale = state["scale"] buf.mul_(momentum).add_(g) + if group["nesterov"]: g = g.add(buf, alpha=momentum) else: g = buf + + scale_grad = (g * p).sum() u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) # scale update adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - # apply weight decay - p.data.mul_(1 - lr * wd) + + old_scale = scale.clone() + scale.mul_(1 - lr * wd) + + scale.add_(scale_grad.sign(), alpha=-lr) + scale.clamp_(min=min_scale, max=max_scale) + + scale_ratio = scale / old_scale + + # apply changes in scale + p.data.mul_(scale_ratio) # apply update - p.data.add_(u, alpha=-adjusted_lr) + p.data.add_(u * scale, alpha=-adjusted_lr) # Adam backup params = [p for p in group["params"] if not self.state[p]["use_muon"]] From ed8020bd0a6aa974ded0b399c9a74d4f29d73930 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 1 Feb 2026 22:25:06 +0800 Subject: [PATCH 0840/1191] possibly a bug fix --- egs/librispeech/ASR/zipformer/muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 403a438199..3654127294 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -229,7 +229,7 @@ def step(self, closure=None): else: g = buf - scale_grad = (g * p).sum() + scale_grad = (g * p.detach()).sum() u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) # scale update From 6661ab86077f91d7f1ebef981eb5708e553b3371 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 1 Feb 2026 22:31:39 +0800 Subject: [PATCH 0841/1191] Reduce print prob --- egs/librispeech/ASR/zipformer/muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 3654127294..3d5ef09477 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -110,7 +110,7 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" X = X.T - if random.random() < 0.01: + if random.random() < 0.0001: logging.info(f"zeropower_via_newtonschulz5: shape={X.shape}, singular-value-rms={X.norm()/(min(X.shape[0],X.shape[1])**0.5)}") return X.reshape(orig_shape) From d6e97c245b5dddfbab350c081d3056b981263938 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 1 Feb 2026 23:04:13 +0800 Subject: [PATCH 0842/1191] Change to how weight decay interacts with scale. --- egs/librispeech/ASR/zipformer/muon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 3d5ef09477..c31c33d78d 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -237,15 +237,15 @@ def step(self, closure=None): old_scale = scale.clone() - scale.mul_(1 - lr * wd) scale.add_(scale_grad.sign(), alpha=-lr) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale - # apply changes in scale - p.data.mul_(scale_ratio) + # apply changes in scale, together with conventional decay. + p.data.mul_(scale_ratio * (1 - lr * wd)) + # apply update p.data.add_(u * scale, alpha=-adjusted_lr) From 6f1064e67299e760bc69a3ccf18b1c31dfcf2ea9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 2 Feb 2026 16:06:25 +0800 Subject: [PATCH 0843/1191] Fix issue with how momentum for scaling was calculated. --- egs/librispeech/ASR/zipformer/muon.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index c31c33d78d..a051234488 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -220,16 +220,20 @@ def step(self, closure=None): if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) state["scale"] = torch.tensor(1.0, device=g.device) # scalar + state["scale_grad_buffer"] = torch.tensor(0.0, device=g.device) # scalar buf = state["momentum_buffer"] scale = state["scale"] + scale_grad_buf = state["scale_grad_buffer"] buf.mul_(momentum).add_(g) + scale_grad = (g * p.detach()).sum() + scale_grad_buf.mul_(momentum).add_(scale_grad) + if group["nesterov"]: g = g.add(buf, alpha=momentum) else: g = buf - scale_grad = (g * p.detach()).sum() u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) # scale update @@ -238,7 +242,7 @@ def step(self, closure=None): old_scale = scale.clone() - scale.add_(scale_grad.sign(), alpha=-lr) + scale.add_(scale_grad_buf.sign(), alpha=-lr) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale From 5fcea551c69594118c526a471e70ea77149e17fa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 13:03:39 +0800 Subject: [PATCH 0844/1191] Move multiplication by 10 in FftConv to before weight_proj. --- egs/librispeech/ASR/zipformer/zipformer.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0a28368bd4..7e892e0e89 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -574,18 +574,13 @@ def forward( aux_loss_scale=0.1 * aux_loss_scale, ) - src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + 0.5 * self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.self_attn(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + src = src + self.conv_module(4. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) - # ff2_scale is to keep the inputs to the activations in feed_forward2 - # module at about the same magnitude as in the feed_forward1 module - # without requiring the weights to learn different magnitudes (the - # activation is not scale-invariant, it is not relu). - ff2_scale = 1.5 - src = src + ff2_scale * self.feed_forward2(src * (1. / ff2_scale), aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + 0.5 * self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) #residual_scale = limit_param_value(self.residual_scale, min=0.25, max=0.75) residual_scale = 0.25 @@ -1701,7 +1696,6 @@ def forward(self, # x: (num_freqs, batch_size, num_channels) N = x.shape[0] # num freqs weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) # (num_channels, 2, 2 * params_per_channel) - weight = 10. * weight # this scale of 10 times is because of interactions with commonly # used optimizers, it's to help this module learn faster than it # otherwise would. From 7e5f165a416b317b9cc64a4cf913eab8dca6b2f4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 16:30:42 +0800 Subject: [PATCH 0845/1191] Decrease in_proj scale of conv_module from 4 to 3 and reintroduce scale, of 4, on weight of depthwise_conv. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7e892e0e89..60b947a854 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -578,7 +578,7 @@ def forward( src = src + self.self_attn(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.conv_module(4. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + src = src + self.conv_module(3. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + 0.5 * self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -1695,7 +1695,8 @@ def forward(self, x = torch.fft.rfft(x.to(torch.float32), dim=0) # x: (num_freqs, batch_size, num_channels) N = x.shape[0] # num freqs - weight = self.weight_proj(self.weight).reshape(num_channels, 2, -1) # (num_channels, 2, 2 * params_per_channel) + weight = 4. * self.weight + weight = self.weight_proj(weight).reshape(num_channels, 2, -1) # (num_channels, 2, 2 * params_per_channel) # this scale of 10 times is because of interactions with commonly # used optimizers, it's to help this module learn faster than it # otherwise would. From e1f3222e6223e0a7ef3ab45d55694c2dc12774e8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 17:12:49 +0800 Subject: [PATCH 0846/1191] Restore offset_scale_limiter, add factor of 0.2 in out projection of encoder_embed. --- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 41d3cd9510..f97f785150 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -319,7 +319,7 @@ def streaming_forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, T', out_width * layer3_channels)) - x = self.out(x) + x = 0.2 * self.out(x) # Now x is of shape (N, T', odim) x = self.out_norm(x) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 60b947a854..1e178a528b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -510,7 +510,9 @@ def __init__( self.embed_dim = embed_dim self.name = None # will be set from training loop - self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) + #self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) + + self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) @@ -586,6 +588,8 @@ def forward( residual_scale = 0.25 offset = (src - src_orig) * residual_scale + offset = self.offset_scale_limiter(offset, aux_loss_scale) + src = src_orig + offset src = self.norm(src) From d3950f45fb42d3ec3379037842dbf0d195afb62f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 17:17:59 +0800 Subject: [PATCH 0847/1191] Remove factor of query_head_dim**-0.5 from class MultiheadAttentionWeights in_proj output. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 1e178a528b..a9f70ed46f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1100,7 +1100,7 @@ def forward( query_dim = query_head_dim * num_heads # self-attention - q = x[..., 0:query_dim] * (query_head_dim ** -0.5) + q = x[..., 0:query_dim] k = x[..., query_dim : 2 * query_dim] p = x[..., 2 * query_dim:] From 657466c12a0ca55e960f980e5b5040be8375ed4b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 17:23:54 +0800 Subject: [PATCH 0848/1191] Reduce scaler on out projection of encoder_embed from .2 to .1 --- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index f97f785150..21af6e71af 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -319,7 +319,7 @@ def streaming_forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, T', out_width * layer3_channels)) - x = 0.2 * self.out(x) + x = 0.1 * self.out(x) # Now x is of shape (N, T', odim) x = self.out_norm(x) From 9fbb90424361ecdc974b6121a0cfcebc9134fe69 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 19:44:32 +0800 Subject: [PATCH 0849/1191] Fix to scale in subsampling.py --- egs/librispeech/ASR/zipformer/subsampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 21af6e71af..4b1df261d7 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -265,7 +265,7 @@ def forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, (T-7)//2, out_width * layer3_channels)) - x = self.out(x) + x = 0.15 * self.out(x) # Now x is of shape (N, (T-7)//2, odim) if torch.jit.is_scripting() or torch.jit.is_tracing(): x_lens = (x_lens - 7) // 2 @@ -319,7 +319,7 @@ def streaming_forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, T', out_width * layer3_channels)) - x = 0.1 * self.out(x) + x = 0.15 * self.out(x) # Now x is of shape (N, T', odim) x = self.out_norm(x) From d35ae83dbe7fcb9600923ab520bc99bf2c7947bf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 20:46:42 +0800 Subject: [PATCH 0850/1191] Introduce scale of 5 at output of encoder, and double scale of decoder embeddings. --- egs/librispeech/ASR/zipformer/decoder.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index df41fe29ab..fc6aec95e6 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -96,7 +96,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: y = y.to(torch.int64) # this stuff about clamp() is a temporary fix for a mismatch # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) * 10.0 + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) * 20.0 if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a9f70ed46f..04d8d3e5f3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -270,6 +270,7 @@ def forward( if od > 1: x_lens = (x_lens + od - 1) // od + x = 5.0 * x # scale up x, as the activations at this point 'want' to be fairly small, like 0.2. return x, x_lens def _get_attn_mask( @@ -370,6 +371,7 @@ def streaming_forward( warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 + x = 5.0 * x # scale up x, as the activations at this point 'want' to be fairly small, like 0.2. return x, lengths, new_states @torch.jit.export From a859ce641709ac8c783e9488461457ba24cc6f6d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Feb 2026 23:56:58 +0800 Subject: [PATCH 0851/1191] Increase range of scaling_limits to 0.5 to 2. --- egs/librispeech/ASR/zipformer/muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index a051234488..7d36475178 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -153,7 +153,7 @@ def __init__( adamw_params=None, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, - scale_limits=(0.66, 1.5), + scale_limits=(0.5, 2.0), ): defaults = dict( lr=lr, From 2d7d77e127581187135538336e0ce7fae2dd5e8a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Feb 2026 11:10:23 +0800 Subject: [PATCH 0852/1191] Increase power in sched3 from 0.33 to 0.4. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index fba657e25e..4d5cf94d47 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1386,7 +1386,7 @@ def run(rank, world_size, args): wd=0.15, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.33) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.4) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From f99dca0403b8ca43d1e25ab5e8b9b6be8e701aba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Feb 2026 11:11:22 +0800 Subject: [PATCH 0853/1191] Reduce scale on out projection of encoder_embed from .15 to .1 --- egs/librispeech/ASR/zipformer/subsampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 4b1df261d7..f6f6f0cd7d 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -265,7 +265,7 @@ def forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, (T-7)//2, out_width * layer3_channels)) - x = 0.15 * self.out(x) + x = 0.1 * self.out(x) # Now x is of shape (N, (T-7)//2, odim) if torch.jit.is_scripting() or torch.jit.is_tracing(): x_lens = (x_lens - 7) // 2 @@ -319,7 +319,7 @@ def streaming_forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, T', out_width * layer3_channels)) - x = 0.15 * self.out(x) + x = 0.1 * self.out(x) # Now x is of shape (N, T', odim) x = self.out_norm(x) From be496c74c87d6372b53c5247386a931d422f5df6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Feb 2026 14:20:04 +0800 Subject: [PATCH 0854/1191] Replace ExpNorm with GaussNorm; also add GaussNorm on in_proj of self-attention weights. --- egs/librispeech/ASR/zipformer/scaling.py | 114 +++++++++++++++++++ egs/librispeech/ASR/zipformer/subsampling.py | 15 ++- egs/librispeech/ASR/zipformer/zipformer.py | 7 +- 3 files changed, 126 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index de0cc2c91e..fbad1b8be7 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -440,6 +440,120 @@ def forward(self, x: Tensor) -> Tensor: return ans + +def round_up_to_power_of_two(a): + if a <= 0: + return 1 + else: + return 1 << (a - 1).bit_length() + +def gaussian_blur_1d(x, inv_width, dim): + T = x.shape[dim] + roundT = round_up_to_power_of_two(T) + if roundT > T: + x = torch.cat((x, torch.narrow(x, dim, 0, roundT - T)), dim=dim) + # now x length is power of 2. + seq_len = x.shape[dim] + x = torch.fft.rfft(x.to(torch.float32), dim=dim) + # x is complex. + N = x.shape[dim] + freq = torch.arange(N, device=x.device) / (N - 1) # this is proportional to normalized frequency betwen 0 and 1 + for _ in range(dim, x.ndim - 1): + freq = freq.unsqueeze(-1) + scale = (-(freq * inv_width) ** 2).exp() + x = x * scale # down-weight higher frequencies + x = torch.fft.irfft(x, n=seq_len, dim=dim) + x = torch.narrow(x, dim, 0, T) + return x + + +# assume layout: (time, batch, channel) +def _gauss_norm(x: Tensor, blur: Tensor, scale: Tensor): + eps = 1.0e-02 + x_sq = torch.mean(x ** 2, dim=2, keepdim=True).clamp(min=eps) + x_sq_blurred = gaussian_blur_1d(x_sq, blur, dim=0) + x_sq = torch.maximum(x_sq_blurred.clamp(min=eps), 0.2 * x_sq) # may be overkill + + scales = scale / x_sq.sqrt() + return x * scales + + +class GaussNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + blur: Tensor, + scale: Tensor, + ) -> Tensor: + ctx.save_for_backward(x, blur, scale) + return _gauss_norm(x, blur, scale) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, blur, scale = ctx.saved_tensors + + with torch.amp.autocast('cuda', enabled=False): + x, blur, scale = x.to(torch.float32), blur.to(torch.float32), scale.to(torch.float32) + x, blur, scale = x.detach(), blur.detach(), scale.detach() + + x.requires_grad = True + scale.requires_grad = True + blur.requires_grad = True + + with torch.enable_grad(): + ans = _gauss_norm(x, blur, scale) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode. + return x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(blur.grad), c(scale.grad) + + +class GaussNorm(torch.nn.Module): + """ + This is like RMSNorm with a trainable scale, but also blurs the rms values along + the time axis by convolving with a learnable width of Gaussian. + + """ + def __init__( + self, + ) -> None: + super(GaussNorm, self).__init__() + self.scale = nn.Parameter(torch.tensor(0.2)) # output scale + self.blur = nn.Parameter(torch.tensor(0.5)) # larger value -> more blur, will multiply this by 20, then it's like an inverse width. + self.name = None + + + def forward(self, x: Tensor) -> Tensor: + # Assumes layout is (time, batch, channel) + + blur_factor = 20.0 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _gauss_norm(x, self.blur * blur_factor, self.scale) + + scale = limit_param_value( + self.scale, min=0.1, max=1.0, training=self.training) + + blur = blur_factor * limit_param_value( + self.blur, min=0.0, max=3.0, training=self.training) + + ans = GaussNormFunction.apply( + x, blur, scale, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, blur={blur.item()}, scale={scale.item()}") + + return ans + + def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index f6f6f0cd7d..74e16ade21 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -23,7 +23,7 @@ from scaling import ( ScaleLimiter, ScaledLinear, - ExpNorm, + GaussNorm, FloatLike, get_max_similarity, ScaledConv2d, @@ -231,13 +231,9 @@ def __init__( # scale it up a bit, else the output is quite small. - self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, - initial_scale=4.0) + self.out = ScaledLinear(self.out_width * layer3_channels, out_channels) - - self.scale_limiter = ScaleLimiter(max_rms=2.0) - - self.out_norm = ExpNorm(out_channels) + self.out_norm = GaussNorm() def forward( self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, @@ -276,8 +272,11 @@ def forward( key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) # key_padding_mask: (N, (T-7)//2) - x = self.scale_limiter(x, aux_loss_scale) + x = x.permute(1, 0, 2) + # x: (time, batch, channels) x = self.out_norm(x) + x = x.permute(1, 0, 2) + # x: (batch, time, channels) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 04d8d3e5f3..bf96aa3010 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -31,7 +31,7 @@ SimpleOrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, - ExpNorm, + GaussNorm, ChunkCausalDepthwiseConv1d, CosineSimilarityLoss, ScheduledFloat, @@ -534,7 +534,7 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal) - self.norm = ExpNorm(embed_dim) + self.norm = GaussNorm() def forward( @@ -1054,6 +1054,8 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. + self.in_norm = GaussNorm() + key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads @@ -1095,6 +1097,7 @@ def forward( """ query_head_dim = self.query_head_dim num_heads = self.num_heads + x = self.in_norm(x) x = self.in_proj(x) seq_len, batch_size, _ = x.shape From 31ede5ec5e44e7617ed0a730a5434e61175d7d7e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Feb 2026 23:41:01 +0800 Subject: [PATCH 0855/1191] Remove various scales introduced in 2005: the 0.5 in feedforward, 0.1 in subsampling.py, 5.0 in zipformer output. decrease min GaussNorm scale from .1 to .05. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/subsampling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 6 ++---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index fbad1b8be7..3aa688fae9 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -537,7 +537,7 @@ def forward(self, x: Tensor) -> Tensor: return _gauss_norm(x, self.blur * blur_factor, self.scale) scale = limit_param_value( - self.scale, min=0.1, max=1.0, training=self.training) + self.scale, min=0.05, max=1.0, training=self.training) blur = blur_factor * limit_param_value( self.blur, min=0.0, max=3.0, training=self.training) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 74e16ade21..e888069972 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -261,7 +261,7 @@ def forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, (T-7)//2, out_width * layer3_channels)) - x = 0.1 * self.out(x) + x = self.out(x) # Now x is of shape (N, (T-7)//2, odim) if torch.jit.is_scripting() or torch.jit.is_tracing(): x_lens = (x_lens - 7) // 2 diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bf96aa3010..a767f5770c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -270,7 +270,6 @@ def forward( if od > 1: x_lens = (x_lens + od - 1) // od - x = 5.0 * x # scale up x, as the activations at this point 'want' to be fairly small, like 0.2. return x, x_lens def _get_attn_mask( @@ -371,7 +370,6 @@ def streaming_forward( warnings.simplefilter("ignore") lengths = (x_lens + 1) // 2 - x = 5.0 * x # scale up x, as the activations at this point 'want' to be fairly small, like 0.2. return x, lengths, new_states @torch.jit.export @@ -578,13 +576,13 @@ def forward( aux_loss_scale=0.1 * aux_loss_scale, ) - src = src + 0.5 * self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.self_attn(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module(3. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) - src = src + 0.5 * self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) #residual_scale = limit_param_value(self.residual_scale, min=0.25, max=0.75) residual_scale = 0.25 From 6dcd109e5c5d23f816162c712074aa8d2f0b7c78 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 10:50:41 +0800 Subject: [PATCH 0856/1191] Introduce learnable, additive eps into gauss_norm. --- egs/librispeech/ASR/zipformer/scaling.py | 39 +++++++++++++----------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3aa688fae9..b6e079f485 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -468,12 +468,11 @@ def gaussian_blur_1d(x, inv_width, dim): # assume layout: (time, batch, channel) -def _gauss_norm(x: Tensor, blur: Tensor, scale: Tensor): - eps = 1.0e-02 - x_sq = torch.mean(x ** 2, dim=2, keepdim=True).clamp(min=eps) +def _gauss_norm(x: Tensor, blur: Tensor, eps: Tensor, scale: Tensor): + eps_sq = eps * eps + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) x_sq_blurred = gaussian_blur_1d(x_sq, blur, dim=0) - x_sq = torch.maximum(x_sq_blurred.clamp(min=eps), 0.2 * x_sq) # may be overkill - + x_sq = x_sq_blurred.relu() + eps_sq scales = scale / x_sq.sqrt() return x * scales @@ -484,26 +483,28 @@ def forward( ctx, x: Tensor, blur: Tensor, + eps: Tensor, scale: Tensor, ) -> Tensor: - ctx.save_for_backward(x, blur, scale) - return _gauss_norm(x, blur, scale) + ctx.save_for_backward(x, blur, eps, scale) + return _gauss_norm(x, blur, eps, scale) @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, blur, scale = ctx.saved_tensors + x, blur, eps, scale = ctx.saved_tensors with torch.amp.autocast('cuda', enabled=False): - x, blur, scale = x.to(torch.float32), blur.to(torch.float32), scale.to(torch.float32) - x, blur, scale = x.detach(), blur.detach(), scale.detach() + x, blur, eps, scale = x.to(torch.float32), blur.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, blur, eps, scale = x.detach(), blur.detach(), eps.detach(), scale.detach() x.requires_grad = True - scale.requires_grad = True blur.requires_grad = True + eps.requires_grad = True + scale.requires_grad = True with torch.enable_grad(): - ans = _gauss_norm(x, blur, scale) + ans = _gauss_norm(x, blur, eps, scale) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -511,7 +512,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(blur.grad), c(scale.grad) + return x.grad, c(blur.grad), c(eps.grad), c(scale.grad) class GaussNorm(torch.nn.Module): @@ -526,6 +527,7 @@ def __init__( super(GaussNorm, self).__init__() self.scale = nn.Parameter(torch.tensor(0.2)) # output scale self.blur = nn.Parameter(torch.tensor(0.5)) # larger value -> more blur, will multiply this by 20, then it's like an inverse width. + self.eps = nn.Parameter(torch.tensor(0.1)) self.name = None @@ -534,22 +536,25 @@ def forward(self, x: Tensor) -> Tensor: blur_factor = 20.0 if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _gauss_norm(x, self.blur * blur_factor, self.scale) + return _gauss_norm(x, self.blur * blur_factor, self.eps, self.scale) scale = limit_param_value( self.scale, min=0.05, max=1.0, training=self.training) + eps = limit_param_value( + self.eps, min=0.0, max=10.0, training=self.training) + blur = blur_factor * limit_param_value( self.blur, min=0.0, max=3.0, training=self.training) ans = GaussNormFunction.apply( - x, blur, scale, + x, blur, eps, scale, ) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, blur={blur.item()}, scale={scale.item()}") + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, blur={blur.item()}, eps={eps.item()}, scale={scale.item()}") return ans @@ -1210,7 +1215,7 @@ class SimpleOrthogonalLinear(nn.Linear): def __init__(self, in_channels: int, out_channels: int, - lr_scale: float, + lr_scale: float = 1.0, bias: bool = True, penalty_scale: FloatLike = 20.0, ): From 927853ec2fc4622939272f3b2b09326c89497558 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 12:48:29 +0800 Subject: [PATCH 0857/1191] Remove normalization in frontend. --- egs/librispeech/ASR/zipformer/subsampling.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index e888069972..0ac0919f3b 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -23,7 +23,6 @@ from scaling import ( ScaleLimiter, ScaledLinear, - GaussNorm, FloatLike, get_max_similarity, ScaledConv2d, @@ -233,8 +232,6 @@ def __init__( # scale it up a bit, else the output is quite small. self.out = ScaledLinear(self.out_width * layer3_channels, out_channels) - self.out_norm = GaussNorm() - def forward( self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -274,7 +271,6 @@ def forward( # key_padding_mask: (N, (T-7)//2) x = x.permute(1, 0, 2) # x: (time, batch, channels) - x = self.out_norm(x) x = x.permute(1, 0, 2) # x: (batch, time, channels) @@ -318,10 +314,7 @@ def streaming_forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, T', out_width * layer3_channels)) - x = 0.1 * self.out(x) # Now x is of shape (N, T', odim) - x = self.out_norm(x) - if torch.jit.is_scripting() or torch.jit.is_tracing(): assert self.convnext.padding[0] == 3 # The ConvNeXt module needs 3 frames of right padding after subsampling From ab8d24dbfc37dbaea4328f89b2c22f20a34aa225 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 13:37:52 +0800 Subject: [PATCH 0858/1191] Insert factor of 0.15 into subsampling.py --- egs/librispeech/ASR/zipformer/subsampling.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 0ac0919f3b..839e848e69 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -268,15 +268,10 @@ def forward( x_lens = (x_lens - 7) // 2 key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) - # key_padding_mask: (N, (T-7)//2) - x = x.permute(1, 0, 2) - # x: (time, batch, channels) - x = x.permute(1, 0, 2) - # x: (batch, time, channels) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) - return x, x_lens + return 0.15 * x, x_lens def streaming_forward( self, @@ -328,7 +323,7 @@ def streaming_forward( assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) - return x, x_lens, cached_left_pad + return 0.15 * x, x_lens, cached_left_pad @torch.jit.export def get_init_states( From 9dae0fc311bf0cdbce5e77d339392c9a1391a2fc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 18:15:19 +0800 Subject: [PATCH 0859/1191] Set minimum blur to 1, not 0. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b6e079f485..b212b17481 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -545,7 +545,7 @@ def forward(self, x: Tensor) -> Tensor: self.eps, min=0.0, max=10.0, training=self.training) blur = blur_factor * limit_param_value( - self.blur, min=0.0, max=3.0, training=self.training) + self.blur, min=0.05, max=3.0, training=self.training) ans = GaussNormFunction.apply( x, blur, eps, scale, From ac80d810e062bb4ec48fabd5731d12fb75e6cdcf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 18:21:54 +0800 Subject: [PATCH 0860/1191] min_blur of 1 not used in self-attn in_proj. --- egs/librispeech/ASR/zipformer/scaling.py | 4 +++- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b212b17481..8d18ff8fe2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -523,11 +523,13 @@ class GaussNorm(torch.nn.Module): """ def __init__( self, + min_blur: float = 0.0, ) -> None: super(GaussNorm, self).__init__() self.scale = nn.Parameter(torch.tensor(0.2)) # output scale self.blur = nn.Parameter(torch.tensor(0.5)) # larger value -> more blur, will multiply this by 20, then it's like an inverse width. self.eps = nn.Parameter(torch.tensor(0.1)) + self.min_blur = min_blur self.name = None @@ -545,7 +547,7 @@ def forward(self, x: Tensor) -> Tensor: self.eps, min=0.0, max=10.0, training=self.training) blur = blur_factor * limit_param_value( - self.blur, min=0.05, max=3.0, training=self.training) + self.blur, min=self.min_blur/blur_factor, max=3.0, training=self.training) ans = GaussNormFunction.apply( x, blur, eps, scale, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a767f5770c..297235dd57 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -532,7 +532,7 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal) - self.norm = GaussNorm() + self.norm = GaussNorm(min_blur=1.0) def forward( From 9905c522bdf22622bc89e209b352afcacaf7841d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 20:58:24 +0800 Subject: [PATCH 0861/1191] Use SequenceNorm at layer level and RmsNorm in self_attn_weights --- egs/librispeech/ASR/zipformer/scaling.py | 132 ++++++++------------- egs/librispeech/ASR/zipformer/zipformer.py | 9 +- 2 files changed, 54 insertions(+), 87 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8d18ff8fe2..6ec750c03c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,33 +333,36 @@ def backward(ctx, x_grad, *args): -def _exp_norm(x: Tensor, scale: Tensor, channel_dim: int): - x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - num = (x_norm + 0.05).tanh() - scales = num / x_norm - scales = scale * scales - return (x * scales) - -class ExpNormFunction(torch.autograd.Function): +def _sequence_norm(x: Tensor, scale: Tensor, mask: Optional[Tensor]): + if mask is None: + scales = 1.0 / (x ** 2).mean(dim=(0, 2), keepdim=True).sqrt() + else: + mask = mask.to(torch.float).t().unsqueeze(-1) + x = x * mask + num_frames = mask.sum(dim=0) + scales = num_frames / (x ** 2).sum(dim=0).mean(dim=1, keepdim=True).sqrt() + + return x * (scale * scales) + + +class SequenceNormFunction(torch.autograd.Function): @staticmethod def forward( ctx, x: Tensor, scale: Tensor, - channel_dim: int, + mask: Optional[Tensor], ) -> Tensor: - if channel_dim < 0: - channel_dim = channel_dim + x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x, scale) + ctx.mask = mask - return _exp_norm(x, scale, channel_dim) + return _sequence_norm(x, scale, mask) @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: x, scale = ctx.saved_tensors + mask = ctx.mask with torch.amp.autocast('cuda', enabled=False): x, scale = x.to(torch.float32), scale.to(torch.float32) @@ -369,7 +372,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: scale.requires_grad = True with torch.enable_grad(): - ans = _exp_norm(x, scale, ctx.channel_dim) + ans = _sequence_norm(x, scale, ctx.mask) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -380,7 +383,7 @@ def c(x): return x.grad, c(scale.grad), None -class ExpNorm(torch.nn.Module): +class SequenceNorm(torch.nn.Module): """ This is intended to be a simpler, and hopefully cheaper, replacement for LayerNorm, without the learned weight or bias. There is just one learned @@ -409,28 +412,25 @@ class ExpNorm(torch.nn.Module): """ def __init__( self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. ) -> None: - super(ExpNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.scale = nn.Parameter(torch.tensor(1.7)) + super(SequenceNorm, self).__init__() + self.scale = nn.Parameter(torch.tensor(0.5)) self.name = None - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels + def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: + # x: (seq, batch, channel) + # mask: bool, (batch_size, seq_len) if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _exp_norm(x, self.scale, self.channel_dim) + return _sequence_norm(x, self.scale, mask) scale = limit_param_value( - self.scale, min=0.8, max=2.5, training=self.training) + self.scale, min=0.05, max=1.0, training=self.training) - ans = ExpNormFunction.apply( - x, scale, self.channel_dim, + ans = SequenceNormFunction.apply( + x, scale, mask, ) if random.random() < 0.002: @@ -441,70 +441,44 @@ def forward(self, x: Tensor) -> Tensor: return ans -def round_up_to_power_of_two(a): - if a <= 0: - return 1 - else: - return 1 << (a - 1).bit_length() - -def gaussian_blur_1d(x, inv_width, dim): - T = x.shape[dim] - roundT = round_up_to_power_of_two(T) - if roundT > T: - x = torch.cat((x, torch.narrow(x, dim, 0, roundT - T)), dim=dim) - # now x length is power of 2. - seq_len = x.shape[dim] - x = torch.fft.rfft(x.to(torch.float32), dim=dim) - # x is complex. - N = x.shape[dim] - freq = torch.arange(N, device=x.device) / (N - 1) # this is proportional to normalized frequency betwen 0 and 1 - for _ in range(dim, x.ndim - 1): - freq = freq.unsqueeze(-1) - scale = (-(freq * inv_width) ** 2).exp() - x = x * scale # down-weight higher frequencies - x = torch.fft.irfft(x, n=seq_len, dim=dim) - x = torch.narrow(x, dim, 0, T) - return x - # assume layout: (time, batch, channel) -def _gauss_norm(x: Tensor, blur: Tensor, eps: Tensor, scale: Tensor): - eps_sq = eps * eps - x_sq = torch.mean(x ** 2, dim=2, keepdim=True) - x_sq_blurred = gaussian_blur_1d(x_sq, blur, dim=0) - x_sq = x_sq_blurred.relu() + eps_sq +def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) scales = scale / x_sq.sqrt() return x * scales -class GaussNormFunction(torch.autograd.Function): +class GaussNorm: + # this is to prevent errors when running multiple jobs. + pass + +class RmsNormFunction(torch.autograd.Function): @staticmethod def forward( ctx, x: Tensor, - blur: Tensor, eps: Tensor, scale: Tensor, ) -> Tensor: - ctx.save_for_backward(x, blur, eps, scale) - return _gauss_norm(x, blur, eps, scale) + ctx.save_for_backward(x, eps, scale) + return _rms_norm(x, eps, scale) @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, blur, eps, scale = ctx.saved_tensors + x, eps, scale = ctx.saved_tensors with torch.amp.autocast('cuda', enabled=False): - x, blur, eps, scale = x.to(torch.float32), blur.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) - x, blur, eps, scale = x.detach(), blur.detach(), eps.detach(), scale.detach() + x, eps, scale = x.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, eps, scale = x.detach(), eps.detach(), scale.detach() x.requires_grad = True - blur.requires_grad = True eps.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - ans = _gauss_norm(x, blur, eps, scale) + ans = _rms_norm(x, eps, scale) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -512,33 +486,28 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(blur.grad), c(eps.grad), c(scale.grad) + return x.grad, c(eps.grad), c(scale.grad) -class GaussNorm(torch.nn.Module): +class RmsNorm(torch.nn.Module): """ - This is like RMSNorm with a trainable scale, but also blurs the rms values along - the time axis by convolving with a learnable width of Gaussian. + This is like RMSNorm with a trainable scale. """ def __init__( self, - min_blur: float = 0.0, ) -> None: - super(GaussNorm, self).__init__() + super(RmsNorm, self).__init__() self.scale = nn.Parameter(torch.tensor(0.2)) # output scale - self.blur = nn.Parameter(torch.tensor(0.5)) # larger value -> more blur, will multiply this by 20, then it's like an inverse width. self.eps = nn.Parameter(torch.tensor(0.1)) - self.min_blur = min_blur self.name = None def forward(self, x: Tensor) -> Tensor: # Assumes layout is (time, batch, channel) - blur_factor = 20.0 if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _gauss_norm(x, self.blur * blur_factor, self.eps, self.scale) + return _rms_norm(x, self.eps, self.scale) scale = limit_param_value( self.scale, min=0.05, max=1.0, training=self.training) @@ -546,17 +515,14 @@ def forward(self, x: Tensor) -> Tensor: eps = limit_param_value( self.eps, min=0.0, max=10.0, training=self.training) - blur = blur_factor * limit_param_value( - self.blur, min=self.min_blur/blur_factor, max=3.0, training=self.training) - - ans = GaussNormFunction.apply( - x, blur, eps, scale, + ans = RmsNormFunction.apply( + x, eps, scale, ) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, blur={blur.item()}, eps={eps.item()}, scale={scale.item()}") + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, eps={eps.item()}, scale={scale.item()}") return ans diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 297235dd57..06bfa7edb1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -31,7 +31,6 @@ SimpleOrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ActivationDropoutAndLinear, - GaussNorm, ChunkCausalDepthwiseConv1d, CosineSimilarityLoss, ScheduledFloat, @@ -46,6 +45,8 @@ ) try: from scaling import CorrelationLimiter + from scaling import SequenceNorm + from scaling import RmsNorm except: pass @@ -532,7 +533,7 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal) - self.norm = GaussNorm(min_blur=1.0) + self.norm = SequenceNorm() def forward( @@ -592,7 +593,7 @@ def forward( src = src_orig + offset - src = self.norm(src) + src = self.norm(src, src_key_padding_mask) return src @@ -1052,7 +1053,7 @@ def __init__( self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. - self.in_norm = GaussNorm() + self.in_norm = RmsNorm() key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads From 2409e1ce5aedbf246a72560b60d2a17f18d2a6a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 21:10:09 +0800 Subject: [PATCH 0862/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6ec750c03c..fbd94a5406 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -337,7 +337,7 @@ def _sequence_norm(x: Tensor, scale: Tensor, mask: Optional[Tensor]): if mask is None: scales = 1.0 / (x ** 2).mean(dim=(0, 2), keepdim=True).sqrt() else: - mask = mask.to(torch.float).t().unsqueeze(-1) + mask = (~mask).to(torch.float).t().unsqueeze(-1) x = x * mask num_frames = mask.sum(dim=0) scales = num_frames / (x ** 2).sum(dim=0).mean(dim=1, keepdim=True).sqrt() From ac468628d955a72f9b4cb4f2ec5dcade5950b688 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Feb 2026 21:30:33 +0800 Subject: [PATCH 0863/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index fbd94a5406..0e245e0311 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -340,7 +340,7 @@ def _sequence_norm(x: Tensor, scale: Tensor, mask: Optional[Tensor]): mask = (~mask).to(torch.float).t().unsqueeze(-1) x = x * mask num_frames = mask.sum(dim=0) - scales = num_frames / (x ** 2).sum(dim=0).mean(dim=1, keepdim=True).sqrt() + scales = (num_frames / (x ** 2).sum(dim=0).mean(dim=1, keepdim=True)).sqrt() return x * (scale * scales) From 19866819d47414d1c40e03df4339d7515f14c36a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 7 Feb 2026 00:20:33 +0800 Subject: [PATCH 0864/1191] Revert optimizer to version in 1898 (TransformedAdam with decay using scale_by()); name lr_scale as lr_scale so it works with that optimizer. --- egs/librispeech/ASR/zapformer/train.py | 23 +++++++++++------------ egs/librispeech/ASR/zipformer/scaling.py | 8 ++++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 4d5cf94d47..061f95aa81 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -75,7 +75,6 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Sched3, TransformedAdam -from muon import Muon from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -441,7 +440,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.001, help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( @@ -1122,8 +1121,8 @@ def get_scaler_scale(): return 1.0 def save_bad_model(suffix: str = ""): - #if params.debug_interval > 0: - # optimizer.write_debug_info(summary_writer=tb_writer) + if params.debug_interval > 0: + optimizer.write_debug_info(summary_writer=tb_writer) save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, @@ -1277,8 +1276,8 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) - #if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: - #optimizer.write_debug_info(summary_writer=tb_writer) + if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: + optimizer.write_debug_info(summary_writer=tb_writer) loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value @@ -1379,14 +1378,14 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = Muon( - muon_params=[ m for m in model.parameters() if m.numel() != max(m.shape, default=1) ], - adamw_params=[ m for m in model.parameters() if m.numel() == max(m.shape, default=1) ], - lr=params.base_lr, - wd=0.15, + optimizer = TransformedAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.4) + scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 0e245e0311..620d1760b3 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1190,7 +1190,7 @@ def __init__(self, super().__init__(in_channels, out_channels, bias=bias) self.name = None self.penalty_scale = copy.deepcopy(penalty_scale) - self.weight_scale = lr_scale + self.lr_scale = lr_scale with torch.no_grad(): self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * (1. / lr_scale) @@ -1201,9 +1201,9 @@ def __init__(self, def forward(self, x: Tensor, transpose: bool = False): # you can only use transpose=True if you used bias=False in initialization weight = self.weight - weight_scale = self.weight_scale - if weight_scale != 1.0: - weight = weight * weight_scale + lr_scale = self.lr_scale + if lr_scale != 1.0: + weight = weight * lr_scale if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): weight = SimpleOrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) From 3aaa03da9a50c9c66c4d88352efac2bd1e53cc35 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 7 Feb 2026 00:40:33 +0800 Subject: [PATCH 0865/1191] Add eps in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6961fc4760..9a07db6cfe 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -233,7 +233,7 @@ def scale_by(x, beta1): # target_ratio is the ratio between the variance we want, to the variance we got # with this alpha value. it - target_ratio = (beta1_2 * x2_sum) / (x2_sum - 2 * alpha * x4_sum + alpha**2 * x6_sum) + target_ratio = (beta1_2 * x2_sum + eps) / (x2_sum - 2 * alpha * x4_sum + alpha**2 * x6_sum + eps) post_scale = target_ratio ** 0.5 # post-scaling on x, after applying alpha. From d2cc6f34653f56e3a78573c9a5565f07ea450789 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Feb 2026 15:52:43 +0800 Subject: [PATCH 0866/1191] Replace Sched3 with cosine learning rate schedule with 2000 batches of warmup from 0.5. --- egs/librispeech/ASR/zapformer/train.py | 36 +++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 061f95aa81..d4188f8366 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -56,6 +56,7 @@ import copy import logging import warnings +import math from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -75,6 +76,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Sched3, TransformedAdam +from torch.optim.lr_scheduler import LambdaLR from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -116,22 +118,6 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: ) -def get_adjusted_lr_batches(params: AttributeDict) -> float: - # returns an adjusted form of the "lr_batches" parameter used to set the learning - # rate in the Sched3 scheduler. - # We want the final LR to be based on the geometric mean of "how much data we - # have seen" and "how many batches we have seen". - # an easier way to look at it is this: the formula for learning rate depends - # on (cur_batch / lr_batches). if we write this as: - # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches - # then the numerator is a geometric mean of "how many batches we have seen" - # and "how much data we have seen". We can achieve this by setting - # lr_batches = params.lr_batches * (duration_ratio ** -0.5). - duration_ratio = (params.max_duration * params.world_size) / params.ref_duration - lr_batches = params.lr_batches * (duration_ratio ** -0.5) - logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") - return lr_batches - def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): @@ -1161,7 +1147,7 @@ def save_bad_model(suffix: str = ""): # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) + scheduler.step() scaler.step(optimizer) scaler.update() @@ -1385,7 +1371,21 @@ def run(rank, world_size, args): debug_interval=params.debug_interval, ) - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) + warmup_steps = 2000 + # hardcode batches per epoch for now. + total_steps = 4550 * params.num_epochs + warmup_start = 0.5 + def lr_lambda(current_step): + if current_step < warmup_steps: + # Linear warm-up + return warmup_start + (1.0 - warmup_start) * current_step / warmup_steps + else: + # Cosine annealing + progress = (current_step - warmup_steps) / (total_steps - warmup_steps) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + + scheduler = LambdaLR(optimizer, lr_lambda) + if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From ed0a88d58666f98581c86ab502e634d009014dae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Feb 2026 18:38:28 +0800 Subject: [PATCH 0867/1191] Add RmsNorm at output of zipformer. --- egs/librispeech/ASR/zipformer/zipformer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 06bfa7edb1..7859658d66 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -170,6 +170,9 @@ def _to_tuple(x): self.encoders = nn.ModuleList(encoders) + self.out_norm = RmsNorm() + + def get_chunk_info(self) -> Tuple[int, int]: """ Returns chunk_size and left_context_chunks. @@ -271,6 +274,8 @@ def forward( if od > 1: x_lens = (x_lens + od - 1) // od + x = self.out_norm(x) + return x, x_lens def _get_attn_mask( From 1ef50bc3cde046c1f492147e9fb4cc2d0e48bd4d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Feb 2026 21:52:28 +0800 Subject: [PATCH 0868/1191] Update TransformedAdam to deal with scaling factors more like our modified muon and change default base-lr to 0.001. --- egs/librispeech/ASR/zapformer/train.py | 7 +- egs/librispeech/ASR/zipformer/optim.py | 805 +++---------------------- 2 files changed, 91 insertions(+), 721 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 061f95aa81..f32798b30d 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -440,7 +440,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." + "--base-lr", type=float, default=0.001, help="The base learning rate." ) parser.add_argument( @@ -1380,9 +1380,8 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - debug_interval=params.debug_interval, + lr=params.base_lr, + wd=0.15, ) scheduler = Sched3(optimizer, get_adjusted_lr_batches(params), power=0.5) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 9a07db6cfe..6fdcbfa1e6 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -19,6 +19,7 @@ import logging import random from collections import defaultdict +from torch.optim.lr_scheduler import LambdaLR from typing import Dict, List, Optional, Tuple, Union import torch @@ -254,9 +255,11 @@ def momentum_step(group, state, grad): # delta is the normalized gradient; the rms of delta should be around 1. lr = group["lr"] + eps = group["eps"] step = state["step"] beta1 = min(group["beta1"], 1. - 1. / (10. + 0.2 * step)) direct = group["direct"] + min_scale, max_scale = group["scale_limits"] try: stored_delta = state["delta"] @@ -272,145 +275,70 @@ def momentum_step(group, state, grad): stored_delta.mul_(beta1) else: scale_by(stored_delta, beta1) - return ((-lr * (1-direct) * (1-beta1)) * stored_delta) + ((-lr * direct) * delta) - -def basic_momentum_step(group, state, grad, lr, beta): - delta = base_step(group, state, grad) + ans = (((1-direct) * (1-beta1)) * stored_delta) + (direct * delta) + # OK, now divide ans by its rms so it has unit rms + norm_ans = False + if norm_ans: + dims = list(range(1, ans.ndim)) + ans = ans / ((ans ** 2).mean(dim=dims, keepdim=True) + eps).sqrt() + return -lr * ans - step = state["step"] - try: - stored_delta = state["delta"] - except KeyError as e: - assert step < 2 - # scalar. use conventional momentum. - stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) - state["delta"] = stored_delta - - stored_delta.add_(delta) - stored_delta.mul_(beta) - delta = (-lr * (1 - beta)) * stored_delta - return delta - -def get_scale(group, p, grad): - is_weight = (p.ndim > 2) # is weight, not bias. for scalars, we do not - # reach here. 1st dim is batch-of-params dim. - min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] - - dims = tuple(range(1, p.ndim)) - abs_mean = p.abs().mean(dim=dims, keepdim=True) - abs_mean = abs_mean.clamp(min=min_scale) - - scale = abs_mean - - log_scale_grad = (p * grad).sum(dim=dims, keepdim=True) - - return scale, log_scale_grad - - - -def scaling_step(group, p, state, grad): - # returns new parameter. - p_shape = p.shape - - scale, log_scale_grad = get_scale(group, p, grad) +def scaling_step(group, param, state, grad): + delta = momentum_step(group, state, grad) + # delta is the normalized gradient; the rms of delta should be around 1. + lr = group["lr"] + wd = group["wd"] try: - scale_state = state["scale"] + scale = state["scale"] + scale_grad_buf = state["scale_grad_buffer"] except: - scale_state = dict() - state["scale"] = scale_state - scale_state["step"] = state["step"] - - scale_lr = group["lr"] * group["scaling_lr_scale"] - delta_log_scale = basic_momentum_step(group, scale_state, log_scale_grad, - lr=scale_lr, beta=0.9) - # the following is decay of the log scale towards a user-specified default value, like - # AdamW but on the log of the scale. - delta_log_scale = delta_log_scale - (scale_lr * group["scale_decay"]) * (scale.log() - math.log(group["scale_default"])) - - is_weight = (p.ndim > 2) - max_scale = group["weight_max_scale"] if is_weight else group["bias_max_scale"] - min_scale = group["weight_min_scale"] if is_weight else group["bias_min_scale"] - new_scale = (scale * (1. + delta_log_scale)).clamp(min=min_scale, max=max_scale) + shape = [ param.shape[0] ] + [1] * (param.ndim - 1) + scale = torch.ones(*shape, device=grad.device) + scale_grad_buf = torch.zeros(*shape, device=grad.device) + state["scale"] = scale + state["scale_grad_buffer"] = scale_grad_buf - delta = momentum_step(group, state, grad) - - return p * (new_scale / scale) + delta * scale - - - -def debug_step(group, p, state, grad): - debug_interval = group["debug_interval"] - debug_buffer_size = 256 - step = state["step"] - - if p.shape[0] == p.numel(): - p = p + basic_momentum_step(group, state, grad, lr=group["lr"]*group["scalar_lr_scale"], beta=0.9) - else: - p = scaling_step(group, p, state, grad) + momentum = 0.95 + min_scale, max_scale = group["scale_limits"] - if debug_interval == 0 or step % debug_interval != 0: - return p + dims = list(range(1, param.ndim)) - try: - debug_info = state["debug_info"] - except KeyError: - debug_info = torch.zeros(debug_buffer_size, p.shape[0], 2, - device=p.device, dtype=torch.float) - state["debug_info"] = debug_info + scale_grad = (grad * param.detach()).sum(dim=dims, keepdim=True) + scale_grad_buf.mul_(momentum).add_(scale_grad) - is_scalar = (p.numel() == p.shape[0]) + old_scale = scale.clone() - dims = list(range(1, p.ndim)) # e.g. dims to average. - def maybe_rms(x): - if is_scalar: - # the .mean() is just to get rid of those dims. - return x.mean(dim=dims) - else: - return (x ** 2).mean(dim=dims).sqrt() + scale.add_(scale_grad_buf.sign(), alpha=-lr) + scale.clamp_(min=min_scale, max=max_scale) + scale_ratio = scale / old_scale - debug_info = debug_info[(step // debug_interval) % debug_buffer_size] + delta_scale = (scale_ratio * (1 - lr * wd)) - 1 + return param * delta_scale + scale * delta - debug_info[:, 0] = maybe_rms(p) - debug_info[:, 1] = maybe_rms(grad) - - return p +def basic_momentum_step(group, state, grad, lr, beta): + delta = base_step(group, state, grad) -def _write_debug_info(group, state, param_names, summary_writer): - """ - Writes to a Tensorboard, model-debugging information that was accumulated in debug_step. - """ - debug_interval = group["debug_interval"] + step = state["step"] try: - cur_step = state["step"] - debug_info = state["debug_info_cpu"] - except KeyError: - return - - (debug_buffer_size, num_params, _two) = debug_info.shape - - # cur_index would be where the next debug_info would go in the buffer - cur_index = (cur_step // debug_interval) % debug_buffer_size - # roll the data in the buffer so that cur_index goes to position zero. - debug_info = torch.roll(debug_info, -cur_index, 0) + stored_delta = state["delta"] + except KeyError as e: + assert step < 2 + # scalar. use conventional momentum. + stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["delta"] = stored_delta - debug_info = debug_info.to('cpu') + stored_delta.add_(delta) + stored_delta.mul_(beta) - assert len(param_names) == num_params + delta = (-lr * (1 - beta)) * stored_delta + return delta - arange = torch.arange(debug_buffer_size) - steps = debug_interval * (arange - debug_buffer_size) + cur_step - for i, legend in enumerate(['param_rms', 'grad_rms']): - for name, info in zip(param_names, debug_info[..., i].unbind(dim=1)): - debug_str = f"debug/{legend}/{name}" - for step, value in zip(steps.tolist(), info.tolist()): - if step >= 0: - summary_writer.add_scalar(debug_str, value, step) class TransformedAdam(BatchedOptimizer): @@ -429,14 +357,6 @@ class TransformedAdam(BatchedOptimizer): lr: The learning rate. We will typically use a learning rate schedule that starts at 0.03 and decreases over time, i.e. much higher than other common optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. beta2: beta2 is the momentum constant for moving-grad-squared as in Adam. Must satisfy 0 < beta <= beta2 < 1. betas: a list of decay constants for momentum on the parameter-change @@ -452,66 +372,34 @@ class TransformedAdam(BatchedOptimizer): scale_default: A constant that dictates the RMS value to which weight magnitudes decay. scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. eps: A general-purpose epsilon to prevent division by zero - weight_min_scale, weight_max_scale: Minimum and maximum respectively of weight tensor - scales (mean-absolute-value), for purposes of - learning the scale on the parameters. Weight tensors, as distinct from bias - tensors and scalars, are defined as anything with more than one element and ndim > 1. - bias_min_scale, bias_max_scale: Minimum and maximum respetively of bias tensor scales, - defined as anything with more than one element and exactly one tensor dimension i.e. - ndim == 1. - debug_interval: if >0, write some statistics to tensorboard every this-many steps. """ def __init__( self, params, lr=3e-02, - clipping_scale=None, beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, - scale_decay=0.01, - scale_default=0.05, - scalar_lr_scale=0.2, - scaling_lr_scale=0.2, + wd=0.15, eps=1.0e-08, - weight_min_scale=0.005, - weight_max_scale=1.0, - bias_min_scale=1.0e-05, - bias_max_scale=5.0, - clipping_update_period=100, - debug_interval=0, + scale_limits=(0.5, 2.0), ): defaults = dict( lr=lr, - clipping_scale=clipping_scale, beta1=beta1, direct=direct, beta2=beta2, - scale_decay=scale_decay, - scale_default=scale_default, - scalar_lr_scale=scalar_lr_scale, - scaling_lr_scale=scaling_lr_scale, eps=eps, - weight_min_scale=weight_min_scale, - bias_max_scale=bias_max_scale, - bias_min_scale=bias_min_scale, - weight_max_scale=weight_max_scale, - clipping_update_period=clipping_update_period, - debug_interval=debug_interval, + wd=wd, + scale_limits=scale_limits, ) - # If params only contains parameters or group of parameters, - # i.e when parameter names are not given, - # this flag will be set to False in funciton _get_names_of_parameters. - self.show_dominant_parameters = True param_groups, parameters_names = self._get_names_of_parameters(params) super(TransformedAdam, self).__init__(param_groups, defaults) assert len(self.param_groups) == len(parameters_names) self.parameters_names = parameters_names - - def _get_names_of_parameters( self, params_or_named_params ) -> Tuple[List[Dict], List[List[str]]]: @@ -619,6 +507,8 @@ def _get_names_of_parameters( return param_groups, param_groups_names + + def __setstate__(self, state): super(TransformedAdam, self).__setstate__(state) @@ -638,29 +528,10 @@ def step(self, closure=None): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - for p, state, _names in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "TransformedAdam optimizer does not support sparse gradients" - ) - try: cur_step = state["step"] @@ -668,8 +539,10 @@ def step(self, closure=None): state["step"] = 0 cur_step = 0 - grad = (p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)) - p[:] = debug_step(group, p.detach(), state, grad) + if p.numel() == p.shape[0]: + p += basic_momentum_step(group, state, grad, group["lr"], group["beta1"]) + else: + p += scaling_step(group, p.detach(), state, grad) state["step"] = cur_step + 1 @@ -930,55 +803,26 @@ class SimpleTransformedAdam(Optimizer): scales: a list of scales corresponding to each of the betas, that we multiply each momentum-update by. Implicitly there is also a beta=0, scale=1, i.e. a non-decayed update. - scaling_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each non-scalar parameter tensor. If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. - eps: A general-purpose epsilon to prevent division by zero - weight_min_rms: Minimum root-mean-square value of weight tensors, for purposes of - learning the scale on the parameters. Weight tensors are defined - as anything with more than one element and ndim > 1. - bias_min_rms: Minimum root-mean-square value of bias tensors, defined as anything with - more than one element and exactly one tensor dimension i.e. ndim == 1. """ - def __init__( self, params, lr=3e-02, - clipping_scale=None, beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, - scale_decay=0.01, - scale_default=0.05, - scalar_lr_scale=0.1, - scaling_lr_scale=0.2, + wd=0.15, eps=1.0e-08, - weight_min_scale=0.005, - weight_max_scale=1.0, - bias_min_scale=1.0e-05, - bias_max_scale=5.0, - debug_interval=0, + scale_limits=(0.5, 2.0), ): - defaults = dict( lr=lr, - clipping_scale=clipping_scale, beta1=beta1, direct=direct, beta2=beta2, - scale_decay=scale_decay, - scale_default=scale_default, - scalar_lr_scale=scalar_lr_scale, - scaling_lr_scale=scaling_lr_scale, eps=eps, - weight_min_scale=weight_min_scale, - bias_max_scale=bias_max_scale, - bias_min_scale=bias_min_scale, - weight_max_scale=weight_max_scale, - debug_interval=debug_interval, + wd=wd, + scale_limits=scale_limits, ) super().__init__(params, defaults) @@ -1007,7 +851,6 @@ def step(self, closure=None): state = self.state[p] grad = p.grad - try: cur_step = state["step"] except KeyError: @@ -1016,498 +859,17 @@ def step(self, closure=None): def u(x): return x.unsqueeze(0) - p[:] = debug_step(group, u(p.detach()), state, u(grad))[0] + + if p.numel() == 1: + p += basic_momentum_step(group, state, grad, group["lr"], group["beta1"]) + else: + p += scaling_step(group, u(p.detach()), state, u(grad))[0] state["step"] = cur_step + 1 return loss -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - # the user might try to override the base_lr, so don't include this in the state. - # previously they were included. - # "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # the things with base_lrs are a work-around for a previous problem - # where base_lrs were written with the state dict. - base_lrs = self.base_lrs - self.__dict__.update(state_dict) - self.base_lrs = base_lrs - - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.warning( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - If you don't have the concept of epochs, or one epoch takes a very long time, - you can replace the notion of 'epoch' with some measure of the amount of data - processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to - some measure representing "quite a lot of data": say, one fifth or one third - of an entire training run, but it doesn't matter much. You could also use - Eden2 which has only the notion of batches. - - We suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start - + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - # else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -class Eden2(LRScheduler): - """ - Eden2 scheduler, simpler than Eden because it does not use the notion of epoch, - only batches. - - The basic formula (before warmup) is: - lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup - - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - verbose: bool = False, - ): - super().__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.warmup_batches = warmup_batches - - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.5 - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start - + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - # else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - - - -class Sched3(LRScheduler): - """ - Sched3 scheduler. - - The basic formula is as follows. p is a supplied power, e.g. 1.0, but could - also be, say, 0.8. lr_batches is a number of batches that defines when we start - decreasing significantly. "batch" is the current batch count. - - lr = warmup * [ (p * lr_batches / batch)^p if batch > p*e*lr_batches, else - exp(-batch / (e * lr_batches))) - - where e is the mathematical constant e. This expression is equivalent to: - factor = min_q [ (q * lr_batches) / batch)^q ] where the minimum is taken over - the continuous range 0 <= q <= p. The left hand side of the min in the formula - for lr corresponds to q == p, i.e. we hit the rhs of the allowed range. - - * notes for derivation: define x == lr_batches/batch, and let factor=min_q [(q*x) -. In wolframalpha.com, note that: - d/dp (q * x)^q has a root at (q = 1/(ex)). If 1/ex > p, then q is fixed to the limit, - q==p, so factor == (p * x)^p. Else, i.e. when 1/ex <= p, - when p > 1 / ex, factor == (q * x)^1 = (1/(ex)*x)^(1/ex) = (1/e)^(1/ex = e^{-1/ex}. - - So the rule is: - if batch/(e*lr_batches) > p, i.e. if batch > p*e*lr_batches, - factor = (p * lr_batches/batch)^p. - else, factor = exp(-batch/(lr_batches*e)) - Plot[ If [ x > 0.8 * Exp[1] * 10, 0.8*10/x, Exp[-x/(10*Exp[1])] ], {x, 0, 50}] - - - - - - - `warmup` increases linearly from warmup_start to 1 over `warmup_batches` batches - and then stays constant at 1. - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with TransformedAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - power: float = 1.0, - verbose: bool = False, - ): - super().__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.warmup_batches = warmup_batches - self.power = power - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - lr_batches = self.lr_batches - e = 2.71828 - batch = self.batch - p = self.power - factor = ((p * lr_batches / batch) ** p if batch > p * e * lr_batches else - e ** (-batch / (e * lr_batches))) - - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start - + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - - - - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = TransformedAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -def _test_sched3(): - m = torch.nn.Linear(100, 100) - optim = TransformedAdam(m.parameters(), lr=0.03) - - scheduler = Sched3(optim, lr_batches=100, power=0.5, verbose=True, warmup_batches=20) - - for step in range(300): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - if step % 10 == 0: - logging.info(f"test_sched3: step={step}, last lr = {scheduler.get_last_lr()}") - - logging.info(f"state dict = {scheduler.state_dict()}") - - -# This is included mostly as a baseline for TransformedAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - def _test_transformed_adam(hidden_dim: int): import timeit @@ -1529,7 +891,7 @@ def _test_transformed_adam(hidden_dim: int): input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - for test in [0, 1, 2]: + for test in [0, 1]: fix_random_seed(42) Linear = torch.nn.Linear @@ -1553,16 +915,28 @@ def _test_transformed_adam(hidden_dim: int): for _ in range(20) ] + lr = 0.001 if test == 0: - optim = SimpleTransformedAdam(m.parameters(), lr=0.075, eps=1.0e-20) + optim = TransformedAdam(m.named_parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.95) elif test == 1: - optim = TransformedAdam(m.named_parameters(), lr=0.075, clipping_scale=2.0, eps=1.0e-20) - elif test == 2: - optim = Eve(m.parameters(), lr=0.003) - else: - assert "unknown test", test + optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.95) + + num_epochs = 180 + + warmup_steps = 0 + # hardcode batches per epoch for now. + total_steps = num_epochs + warmup_start = 0.5 + def lr_lambda(current_step): + if current_step < warmup_steps: + # Linear warm-up + return warmup_start + (1.0 - warmup_start) * current_step / warmup_steps + else: + # Cosine annealing + progress = (current_step - warmup_steps) / (total_steps - warmup_steps) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - scheduler = Sched3(optim, lr_batches=120, power=0.9, verbose=False) + scheduler = LambdaLR(optim, lr_lambda) start = timeit.default_timer() avg_loss = 0.0 @@ -1577,7 +951,7 @@ def _test_transformed_adam(hidden_dim: int): # diagnostic = diagnostics.attach_diagnostics(m, opts) for n, (x, y) in enumerate(train_pairs): - scheduler.step_batch() + #scheduler.step_batch() y_out = m(x) loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: @@ -1601,13 +975,12 @@ def _test_transformed_adam(hidden_dim: int): loss.log().backward() optim.step() optim.zero_grad() + scheduler.step() # step once per epoch # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Test={test}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") # logging.info("state dict = ", scheduler.state_dict()) # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") @@ -1731,7 +1104,5 @@ def _test_muon(hidden_dim: int): else: hidden_dim = 200 - _test_muon(hidden_dim) + #_test_muon(hidden_dim) _test_transformed_adam(hidden_dim) - _test_eden() - _test_sched3() From bbd0064e22b19dd8088177f3aba47d4d790ca3e3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Feb 2026 22:49:00 +0800 Subject: [PATCH 0869/1191] Fix import --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6a72a24f02..2ede342ae4 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -75,7 +75,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel -from optim import Sched3, TransformedAdam +from optim import TransformedAdam from torch.optim.lr_scheduler import LambdaLR from scaling import ScheduledFloat from subsampling import Conv2dSubsampling From 310f8d7d996738d1372a2fd67616b1dd506ec8f7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Feb 2026 22:54:55 +0800 Subject: [PATCH 0870/1191] Bug fix --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2ede342ae4..e2d625a0fc 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -105,7 +105,7 @@ str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler def get_adjusted_batch_count(params: AttributeDict) -> float: From 55c2e8e120b192eb1ee19b191c69c579a667d684 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Feb 2026 13:17:06 +0800 Subject: [PATCH 0871/1191] Add Sched3 as dummy class --- egs/librispeech/ASR/zipformer/optim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6fdcbfa1e6..33a762255a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -27,6 +27,8 @@ from torch import Tensor from torch.optim import Optimizer +class Sched3: + pass # fixing multiple-experimental run issue with imports. class BatchedOptimizer(Optimizer): """ From 17155b7b4a7580342a2d447f2e6f1a2f3e951bd1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 7 Feb 2026 14:49:18 +0800 Subject: [PATCH 0872/1191] Increase num layers from 5,7,18,7 to 6,8,20,8. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index e2d625a0fc..3f44d45e5b 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -171,7 +171,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="5,7,18,7", + default="6,8,20,8", help="Number of zipformer encoder layers per stack, comma separated.", ) From 9985a618f2c4937384887c3e9f1394edaea57c09 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 13:28:18 +0800 Subject: [PATCH 0873/1191] Remove final RmsNorm on zipformer. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7859658d66..cfbeb41fed 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -170,8 +170,6 @@ def _to_tuple(x): self.encoders = nn.ModuleList(encoders) - self.out_norm = RmsNorm() - def get_chunk_info(self) -> Tuple[int, int]: """ @@ -274,8 +272,6 @@ def forward( if od > 1: x_lens = (x_lens + od - 1) // od - x = self.out_norm(x) - return x, x_lens def _get_attn_mask( From 1485c636e4aa5eeda76efe5b871b4cb8b065a4bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 14:24:45 +0800 Subject: [PATCH 0874/1191] Decrease warmup_start from .5 to .25. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 3f44d45e5b..b193013717 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1373,7 +1373,7 @@ def run(rank, world_size, args): warmup_steps = 2000 # hardcode batches per epoch for now. total_steps = 4550 * params.num_epochs - warmup_start = 0.5 + warmup_start = 0.25 def lr_lambda(current_step): if current_step < warmup_steps: # Linear warm-up From 76f59bf067c355c6ea4e99dbff056c94c3d46724 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 14:26:30 +0800 Subject: [PATCH 0875/1191] Remove non-functional debugging code. --- egs/librispeech/ASR/zapformer/train.py | 24 ------------------------ egs/librispeech/ASR/zipformer/optim.py | 23 ----------------------- 2 files changed, 47 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b193013717..52d7c207b4 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -388,26 +388,6 @@ def get_parser(): """, ) - parser.add_argument( - "--debug-interval", - type=int, - default=10, - help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for - finding parts of the network that are diverging or not well trained. - """ - ) - - parser.add_argument( - "--dump-debug-interval", - type=int, - default=0, - help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they - are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. - Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file - opened, seeked-in and closed for each scalar that is written). - """ - ) - parser.add_argument( "--exp-dir", type=str, @@ -1107,8 +1087,6 @@ def get_scaler_scale(): return 1.0 def save_bad_model(suffix: str = ""): - if params.debug_interval > 0: - optimizer.write_debug_info(summary_writer=tb_writer) save_checkpoint_impl( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, @@ -1262,8 +1240,6 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) - if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: - optimizer.write_debug_info(summary_writer=tb_writer) loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 33a762255a..3860bf9e20 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -551,29 +551,6 @@ def step(self, closure=None): return loss - @torch.no_grad() - def write_debug_info(self, summary_writer): - if summary_writer is None: - return - logging.info("Writing debug info to tensorboard.") - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - for _p, state, names in batches: - try: - state["debug_info_cpu"] = state["debug_info"].to(device="cpu", non_blocking=True) - except: - pass - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - for _p, state, names in batches: - _write_debug_info(group, state, names, summary_writer) - try: - del state["debug_info_cpu"] - except: - pass - def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] ) -> float: From d03307fc9ff4a18ab139f62c86d02ee616f750ba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 16:28:38 +0800 Subject: [PATCH 0876/1191] Add eps to SequenceNorm; make the eps in RmsNorm not squared. --- egs/librispeech/ASR/zipformer/scaling.py | 60 ++++++++++-------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 620d1760b3..1ed0469841 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,14 +333,14 @@ def backward(ctx, x_grad, *args): -def _sequence_norm(x: Tensor, scale: Tensor, mask: Optional[Tensor]): +def _sequence_norm(x: Tensor, eps: Tensor, scale: Tensor, mask: Optional[Tensor]): if mask is None: - scales = 1.0 / (x ** 2).mean(dim=(0, 2), keepdim=True).sqrt() + scales = 1.0 / ((x ** 2).mean(dim=(0, 2), keepdim=True) + eps).sqrt() else: mask = (~mask).to(torch.float).t().unsqueeze(-1) x = x * mask num_frames = mask.sum(dim=0) - scales = (num_frames / (x ** 2).sum(dim=0).mean(dim=1, keepdim=True)).sqrt() + scales = (num_frames / ((x ** 2).sum(dim=0) + eps).mean(dim=1, keepdim=True)).sqrt() return x * (scale * scales) @@ -351,25 +351,27 @@ def forward( ctx, x: Tensor, scale: Tensor, + eps: Tensor, mask: Optional[Tensor], ) -> Tensor: - ctx.save_for_backward(x, scale) + ctx.save_for_backward(x, eps, scale) ctx.mask = mask - return _sequence_norm(x, scale, mask) + return _sequence_norm(x, eps, scale, mask) @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, scale = ctx.saved_tensors + x, eps, scale = ctx.saved_tensors mask = ctx.mask with torch.amp.autocast('cuda', enabled=False): - x, scale = x.to(torch.float32), scale.to(torch.float32) - x, scale = x.detach(), scale.detach() + x, eps, scale = x.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, eps, scale = x.detach(), eps.detach(), scale.detach() x.requires_grad = True scale.requires_grad = True + eps.requires_grad = True with torch.enable_grad(): ans = _sequence_norm(x, scale, ctx.mask) @@ -380,41 +382,24 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(scale.grad), None + return x.grad, c(eps.grad), c(scale.grad), None class SequenceNorm(torch.nn.Module): """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm, without the learned weight or bias. There is just one learned - parameter, a scalar, which is a scale on the output; and it is limited - during training to the range [0.5..2.5]. - - Unlike LayerNorm it does not pick the scale that maps any rms value at the - input to an rms value of 1 at the output, i.e. the function f(x) = 1 (which - discards the length information); instead, it uses the function: - f(x) = scale * (1 - (-x).exp()), - i.e. if the input rms value was x, it gets mapped to the f(x) above. The - implementation is just: - - x_norm = torch.mean(x ** 2, dim=channel_dim, keepdim=True).sqrt() - scales = (1. - (-x_norm).exp()) / x_norm - return (x * scale * scales) + This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence + as well as the channels; and a padding mask is used for irregular length sequences (actually, + the mask is applied multiplicatively as well.) - where 'scale' is a scalar, and the only learned parameter. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interpreted as an offset from the input's ndim if negative. - This is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. + There is also a learnable scalar scale and a learnable "eps" value. """ def __init__( self, ) -> None: super(SequenceNorm, self).__init__() self.scale = nn.Parameter(torch.tensor(0.5)) + self.eps = nn.Parameter(torch.tensor(0.1)) + self.name = None @@ -424,19 +409,22 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # mask: bool, (batch_size, seq_len) if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _sequence_norm(x, self.scale, mask) + return _sequence_norm(x, self.eps, self.scale, mask) scale = limit_param_value( self.scale, min=0.05, max=1.0, training=self.training) + eps = limit_param_value( + self.eps, min=0.0, max=10.0, training=self.training) + ans = SequenceNormFunction.apply( - x, scale, mask, + x, eps, scale, mask, ) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}") + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, eps={self.eps.item()}") return ans @@ -444,7 +432,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # assume layout: (time, batch, channel) def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): - x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + eps.relu() scales = scale / x_sq.sqrt() return x * scales From 047bb98aa147afd336001053a6155289a69c9ceb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 16:43:10 +0800 Subject: [PATCH 0877/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 1ed0469841..a27a4e8df2 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -374,7 +374,7 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: eps.requires_grad = True with torch.enable_grad(): - ans = _sequence_norm(x, scale, ctx.mask) + ans = _sequence_norm(x, eps, scale, ctx.mask) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): From ae99ead8aeb238b4b0f495a7177f59fffb46c431 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 17:18:17 +0800 Subject: [PATCH 0878/1191] Fix regarding clamping of eps. --- egs/librispeech/ASR/zipformer/scaling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index a27a4e8df2..5a805f9632 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,12 +335,12 @@ def backward(ctx, x_grad, *args): def _sequence_norm(x: Tensor, eps: Tensor, scale: Tensor, mask: Optional[Tensor]): if mask is None: - scales = 1.0 / ((x ** 2).mean(dim=(0, 2), keepdim=True) + eps).sqrt() + scales = 1.0 / ((x ** 2).mean(dim=(0, 2), keepdim=True) + eps).clamp(min=1.0e-05).sqrt() else: mask = (~mask).to(torch.float).t().unsqueeze(-1) x = x * mask num_frames = mask.sum(dim=0) - scales = (num_frames / ((x ** 2).sum(dim=0) + eps).mean(dim=1, keepdim=True)).sqrt() + scales = (num_frames / ((x ** 2).sum(dim=0) + eps).mean(dim=1, keepdim=True)).clamp(min=1.0e-05).sqrt() return x * (scale * scales) @@ -432,8 +432,8 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # assume layout: (time, batch, channel) def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): - x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + eps.relu() - scales = scale / x_sq.sqrt() + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + eps + scales = (scale / x_sq).clamp(min=1.0e-05).sqrt() return x * scales From 8ae22cc98d75fd2ecb14ac114d0dcc1e83f87048 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 17:46:08 +0800 Subject: [PATCH 0879/1191] Reduce initial value of eps --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 5a805f9632..f87743fdaa 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -398,7 +398,7 @@ def __init__( ) -> None: super(SequenceNorm, self).__init__() self.scale = nn.Parameter(torch.tensor(0.5)) - self.eps = nn.Parameter(torch.tensor(0.1)) + self.eps = nn.Parameter(torch.tensor(0.0001)) self.name = None From ba3431fb0950169844628a75fb5394e16cc23b73 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 18:36:00 +0800 Subject: [PATCH 0880/1191] Multiply eps by itself; bug fix. --- egs/librispeech/ASR/zipformer/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index f87743fdaa..da3ac1a84e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -335,12 +335,12 @@ def backward(ctx, x_grad, *args): def _sequence_norm(x: Tensor, eps: Tensor, scale: Tensor, mask: Optional[Tensor]): if mask is None: - scales = 1.0 / ((x ** 2).mean(dim=(0, 2), keepdim=True) + eps).clamp(min=1.0e-05).sqrt() + scales = 1.0 / ((x ** 2).mean(dim=(0, 2), keepdim=True) + eps * eps).clamp(min=1.0e-05).sqrt() else: mask = (~mask).to(torch.float).t().unsqueeze(-1) x = x * mask num_frames = mask.sum(dim=0) - scales = (num_frames / ((x ** 2).sum(dim=0) + eps).mean(dim=1, keepdim=True)).clamp(min=1.0e-05).sqrt() + scales = ((num_frames / ((x ** 2).mean(dim=2, keepdim=True).sum(dim=0))) + eps * eps).clamp(min=1.0e-05).sqrt() return x * (scale * scales) @@ -432,7 +432,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # assume layout: (time, batch, channel) def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): - x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + eps + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) scales = (scale / x_sq).clamp(min=1.0e-05).sqrt() return x * scales From 29bb2378b2f119adfc66428e0f2f80755fc8a73c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 18:48:21 +0800 Subject: [PATCH 0881/1191] Change initial scale from .5 to 1.0 --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index da3ac1a84e..e9ceaad98a 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -397,7 +397,7 @@ def __init__( self, ) -> None: super(SequenceNorm, self).__init__() - self.scale = nn.Parameter(torch.tensor(0.5)) + self.scale = nn.Parameter(torch.tensor(1.0)) self.eps = nn.Parameter(torch.tensor(0.0001)) From 1b2c46be5280b64b3801945e72acc1e2b8971241 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 19:16:58 +0800 Subject: [PATCH 0882/1191] Change eps in SequenceNorm to offset; change how clamp-to-minimum is used in rms_norm, this is bug fix. --- egs/librispeech/ASR/zipformer/scaling.py | 42 ++++++++++++------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e9ceaad98a..f7dd7c1e3e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,16 +333,16 @@ def backward(ctx, x_grad, *args): -def _sequence_norm(x: Tensor, eps: Tensor, scale: Tensor, mask: Optional[Tensor]): +def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): if mask is None: - scales = 1.0 / ((x ** 2).mean(dim=(0, 2), keepdim=True) + eps * eps).clamp(min=1.0e-05).sqrt() + scales = 1.0 / (x ** 2).mean(dim=(0, 2), keepdim=True).sqrt() else: mask = (~mask).to(torch.float).t().unsqueeze(-1) x = x * mask num_frames = mask.sum(dim=0) - scales = ((num_frames / ((x ** 2).mean(dim=2, keepdim=True).sum(dim=0))) + eps * eps).clamp(min=1.0e-05).sqrt() + scales = (num_frames / ((x ** 2).mean(dim=2, keepdim=True).sum(dim=0))).sqrt() - return x * (scale * scales) + return x * ((scale * scales) + offset) class SequenceNormFunction(torch.autograd.Function): @@ -351,30 +351,30 @@ def forward( ctx, x: Tensor, scale: Tensor, - eps: Tensor, + offset: Tensor, mask: Optional[Tensor], ) -> Tensor: - ctx.save_for_backward(x, eps, scale) + ctx.save_for_backward(x, offset, scale) ctx.mask = mask - return _sequence_norm(x, eps, scale, mask) + return _sequence_norm(x, offset, scale, mask) @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, eps, scale = ctx.saved_tensors + x, offset, scale = ctx.saved_tensors mask = ctx.mask with torch.amp.autocast('cuda', enabled=False): - x, eps, scale = x.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) - x, eps, scale = x.detach(), eps.detach(), scale.detach() + x, offset, scale = x.to(torch.float32), offset.to(torch.float32), scale.to(torch.float32) + x, offset, scale = x.detach(), offset.detach(), scale.detach() x.requires_grad = True scale.requires_grad = True - eps.requires_grad = True + offset.requires_grad = True with torch.enable_grad(): - ans = _sequence_norm(x, eps, scale, ctx.mask) + ans = _sequence_norm(x, offset, scale, ctx.mask) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -382,7 +382,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(eps.grad), c(scale.grad), None + return x.grad, c(offset.grad), c(scale.grad), None class SequenceNorm(torch.nn.Module): @@ -391,14 +391,14 @@ class SequenceNorm(torch.nn.Module): as well as the channels; and a padding mask is used for irregular length sequences (actually, the mask is applied multiplicatively as well.) - There is also a learnable scalar scale and a learnable "eps" value. + There is also a learnable scalar scale and a learnable "offset" value. """ def __init__( self, ) -> None: super(SequenceNorm, self).__init__() self.scale = nn.Parameter(torch.tensor(1.0)) - self.eps = nn.Parameter(torch.tensor(0.0001)) + self.offset = nn.Parameter(torch.tensor(0.0001)) self.name = None @@ -409,22 +409,22 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # mask: bool, (batch_size, seq_len) if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _sequence_norm(x, self.eps, self.scale, mask) + return _sequence_norm(x, self.offset, self.scale, mask) scale = limit_param_value( self.scale, min=0.05, max=1.0, training=self.training) - eps = limit_param_value( - self.eps, min=0.0, max=10.0, training=self.training) + offset = limit_param_value( + self.offset, min=0.0, max=10.0, training=self.training) ans = SequenceNormFunction.apply( - x, eps, scale, mask, + x, offset, scale, mask, ) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, eps={self.eps.item()}") + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}") return ans @@ -433,7 +433,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # assume layout: (time, batch, channel) def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) - scales = (scale / x_sq).clamp(min=1.0e-05).sqrt() + scales = (scale / x_sq.clamp(min=1.0e-20)).sqrt() return x * scales From 120fb095ea52588ba460df6801b297fc86561971 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 19:29:22 +0800 Subject: [PATCH 0883/1191] Change default scale back to 0.5. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index f7dd7c1e3e..b668f52b89 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -397,7 +397,7 @@ def __init__( self, ) -> None: super(SequenceNorm, self).__init__() - self.scale = nn.Parameter(torch.tensor(1.0)) + self.scale = nn.Parameter(torch.tensor(0.5)) self.offset = nn.Parameter(torch.tensor(0.0001)) From 275922d26ce606f6446d55d07e57b4c0f4b842f6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 19:46:07 +0800 Subject: [PATCH 0884/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index b668f52b89..d1121fc875 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -433,7 +433,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # assume layout: (time, batch, channel) def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) - scales = (scale / x_sq.clamp(min=1.0e-20)).sqrt() + scales = scale / x_sq.clamp(min=1.0e-20).sqrt() return x * scales From 471ea4781f65b27881888aeb2df991fc506fe38b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 20:09:28 +0800 Subject: [PATCH 0885/1191] Remove some clamping. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d1121fc875..63332fd967 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -433,7 +433,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # assume layout: (time, batch, channel) def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) - scales = scale / x_sq.clamp(min=1.0e-20).sqrt() + scales = scale / x_sq.sqrt() return x * scales From 3a8a25647ebce52d055595efbe5576cb08320714 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Feb 2026 20:19:57 +0800 Subject: [PATCH 0886/1191] Bug fix --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 63332fd967..109da4ad70 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -350,8 +350,8 @@ class SequenceNormFunction(torch.autograd.Function): def forward( ctx, x: Tensor, - scale: Tensor, offset: Tensor, + scale: Tensor, mask: Optional[Tensor], ) -> Tensor: ctx.save_for_backward(x, offset, scale) From f6f7037e5d4eb71f2a254efffa3152e16c773644 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 13 Feb 2026 13:54:10 +0800 Subject: [PATCH 0887/1191] Revert "Remove final RmsNorm on zipformer." This reverts commit 9985a618f2c4937384887c3e9f1394edaea57c09. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index cfbeb41fed..7859658d66 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -170,6 +170,8 @@ def _to_tuple(x): self.encoders = nn.ModuleList(encoders) + self.out_norm = RmsNorm() + def get_chunk_info(self) -> Tuple[int, int]: """ @@ -272,6 +274,8 @@ def forward( if od > 1: x_lens = (x_lens + od - 1) // od + x = self.out_norm(x) + return x, x_lens def _get_attn_mask( From 6d16da78c33fdd7066fcb99da88fc005c6fc7a0a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 13 Feb 2026 13:58:00 +0800 Subject: [PATCH 0888/1191] Increase min of final residual_scale from .25 to .5; increase max of scale in SequenceNorm from 1 to 2. --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 109da4ad70..ecd41a236f 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -412,7 +412,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: return _sequence_norm(x, self.offset, self.scale, mask) scale = limit_param_value( - self.scale, min=0.05, max=1.0, training=self.training) + self.scale, min=0.05, max=2.0, training=self.training) offset = limit_param_value( self.offset, min=0.0, max=10.0, training=self.training) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7859658d66..628e67a117 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -796,7 +796,7 @@ def forward( aux_loss_scale=aux_loss_scale/num_layers, ) residual_scale = limit_param_value(self.residual_scales[i + 1], - min=0.0 if i + 1 < num_layers else 0.25, + min=0.0 if i + 1 < num_layers else 0.5, max=1.0) src_with_bypass = src_with_bypass + residual_scale * src From eda290c0696ef55384e0760348d8bc6a2b11c591 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Feb 2026 12:06:38 +0800 Subject: [PATCH 0889/1191] Reduce num layers from 6,8,20,8 to 6,8,16,8. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 52d7c207b4..c4bf963001 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -171,7 +171,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,20,8", + default="6,8,16,8", help="Number of zipformer encoder layers per stack, comma separated.", ) From de9f0b71d7902cfed4c64d0f3185272a7e1e1f91 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Feb 2026 12:19:44 +0800 Subject: [PATCH 0890/1191] Remove LR warmup and set scale_limits to (1.0, 4.0). --- egs/librispeech/ASR/zapformer/train.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index c4bf963001..d8cdc1af8f 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1344,20 +1344,15 @@ def run(rank, world_size, args): get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, wd=0.15, + scale_limits=(1.0, 4.0), ) - warmup_steps = 2000 # hardcode batches per epoch for now. total_steps = 4550 * params.num_epochs - warmup_start = 0.25 def lr_lambda(current_step): - if current_step < warmup_steps: - # Linear warm-up - return warmup_start + (1.0 - warmup_start) * current_step / warmup_steps - else: - # Cosine annealing - progress = (current_step - warmup_steps) / (total_steps - warmup_steps) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + # Cosine annealing + progress = current_step / total_steps + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) scheduler = LambdaLR(optimizer, lr_lambda) From f6438e419aa1fba13aa7f685a4d039c6f117126f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Feb 2026 15:50:56 +0800 Subject: [PATCH 0891/1191] Simplify amount of x3 that is subtracted. --- egs/librispeech/ASR/zipformer/optim.py | 97 ++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3860bf9e20..0a645d5631 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -152,6 +152,61 @@ def base_step(group, state, grad): return grad / denom +def compute_prod3_inplace(x): # replaces x with x^3 / max(rows, cols), x is interpreted as a batch of matrices. + assert x.ndim >= 3 + + + if x.ndim > 3: + # each tensor in the batch has more than two dimensions. + # reshape to be like a batch of matrices. + # note: x.shape[0] is batch dimension. + if x.shape[1] > x.shape[-1]: + xr = x.reshape(x.shape[0], x.shape[1], -1) + else: + xr = x.reshape(x.shape[0], -1, x.shape[-1]) + compute_prod3_inplace(xr) + if not xr.untyped_storage() is x.untyped_storage(): + x[:] = xr.reshape(*x.shape) + return + if x.shape[1] > x.shape[2]: + xr = x.permute(0, 2, 1) + compute_prod3_inplace(xr) + if not xr.untyped_storage() is x.untyped_storage(): + x[:] = xr.permute(0, 2, 1) + return + + # avoid matrix multiplies by any dimensions that are too large. + max_dim = 1024 + if x.shape[1] > max_dim: + n = x.shape[1] + for divisor in range(2, 100): + if n % divisor == 0 and n // divisor <= max_dim: + xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) + compute_prod3_inplace(xr) + if not xr.untyped_storage() is x.untyped_storage(): + x[:] = xr.reshape(*x.shape) + return + # if no divisor worked, just continue. + + (batch_size, rows, cols) = x.shape # and rows <= cols + + x2 = torch.matmul(x, x.permute(0, 2, 1)) / max(rows, cols) + x3 = torch.matmul(x2, x) + + x[:] = x3 + + + + +def compute_prod3(x): + # computes matrix-matrix-matrix product of batch of matrices x, with reshaping if necessary; + # first divides x by max(num_rows, num_cols) so its a kind of normalized product. + x = x.clone() + compute_prod3_inplace(x) + return x + + + def scale_by(x, beta1): # This is similar in efffect @@ -240,15 +295,16 @@ def scale_by(x, beta1): post_scale = target_ratio ** 0.5 # post-scaling on x, after applying alpha. - x.add_(x3 * alpha[:, None, None], alpha=-1) + x3 = x3 * alpha[:, None, None] + + x.add_(x3, alpha=-1) x *= post_scale[:, None, None] - if False: - print(f"alpha={alpha}, scale={scale * (1-beta1)}") - dot_prod1 = (x * x).sum() - dot_prod2 = (x * x3).sum() * alpha - print(f"dot_prod1={dot_prod1}, dot_prod2={dot_prod2}") + if random.random() < 0.0001: + dot_prod1 = (x * x).sum(dim=(1, 2)) + dot_prod2 = (x * x3).sum(dim=(1, 2)) + logging.info(f"shape={x.shape}, beta1={beta1}, alpha={alpha}, alpha/(((1-beta1)**2)/dim)={alpha/(((1-beta1)**2)/max(rows,cols))}, post_scale={post_scale}, dot_prod_ratio={dot_prod2/dot_prod1}") @@ -272,11 +328,28 @@ def momentum_step(group, state, grad): state["delta"] = stored_delta + + def min_sum_scale(x, y): + # returns the scale alpha such that (x + alpha y) is minimized. x and y have + # the same shape and the shape of alpha is (x.shape[0], 1, 1, ...). + assert x.ndim > 1 + dims = list(range(1, x.ndim)) + xx = (x ** 2).sum(dim=dims, keepdim=True) + yy = (y ** 2).sum(dim=dims, keepdim=True) + xy = (y * x).sum(dim=dims, keepdim=True) + # sum square of x + alpha y is xx + alpha^2 yy + 2 alpha xy + # d/dalpha[that] = 2 alpha yy + 2 xy + # alpha = xy / yy + return -xy / (yy + eps) + stored_delta.add_(delta) - if step % 4 == 0: - stored_delta.mul_(beta1) - else: - scale_by(stored_delta, beta1) + stored_delta.mul_(beta1) + if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): + eta = 1.0 # scale on subtraction of x3. + x3 = compute_prod3(stored_delta) # actually 3rd power of stored_delta divided by max(rows, cols). + update_scale = (-eta * (1 - beta1)**2) + update_scale = min_sum_scale(stored_delta, x3).clamp(min=update_scale) + stored_delta.add_(x3 * update_scale) ans = (((1-direct) * (1-beta1)) * stored_delta) + (direct * delta) # OK, now divide ans by its rms so it has unit rms @@ -896,9 +969,9 @@ def _test_transformed_adam(hidden_dim: int): lr = 0.001 if test == 0: - optim = TransformedAdam(m.named_parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.95) + optim = TransformedAdam(m.named_parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.99) elif test == 1: - optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.95) + optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.99) num_epochs = 180 From a3a753e0a2834c27afcfaa0914e09459d33607b8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Feb 2026 16:28:00 +0800 Subject: [PATCH 0892/1191] Only do half the decay. --- egs/librispeech/ASR/zipformer/optim.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 0a645d5631..43d455ccb3 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -343,13 +343,16 @@ def min_sum_scale(x, y): return -xy / (yy + eps) stored_delta.add_(delta) - stored_delta.mul_(beta1) if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): + stored_delta.mul_(0.5 * (beta1 + 1)) eta = 1.0 # scale on subtraction of x3. x3 = compute_prod3(stored_delta) # actually 3rd power of stored_delta divided by max(rows, cols). update_scale = (-eta * (1 - beta1)**2) update_scale = min_sum_scale(stored_delta, x3).clamp(min=update_scale) stored_delta.add_(x3 * update_scale) + else: + stored_delta.mul_(beta1) + ans = (((1-direct) * (1-beta1)) * stored_delta) + (direct * delta) # OK, now divide ans by its rms so it has unit rms From 4c7264da37099bdbc9b0e423ae6736563cd6688b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Feb 2026 22:01:27 +0800 Subject: [PATCH 0893/1191] Increase weight decay from .15 to .3, fixing issue with 2052. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d8cdc1af8f..74303b8445 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1343,7 +1343,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.15, + wd=0.3, scale_limits=(1.0, 4.0), ) From f25ff06d24346662f55e47c733a28866963521d1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 Feb 2026 22:01:27 +0800 Subject: [PATCH 0894/1191] Increase weight decay from .15 to .3, fixing issue with 2052. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d8cdc1af8f..74303b8445 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1343,7 +1343,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.15, + wd=0.3, scale_limits=(1.0, 4.0), ) From 97ffb3b890fdb847326f217db53d6896b0d299b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 16 Feb 2026 14:39:28 +0800 Subject: [PATCH 0895/1191] Reduce the linear decay rate from one half to one quarter of the beta1-determined rate. --- egs/librispeech/ASR/zipformer/optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 43d455ccb3..7e2d5d34bc 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -344,7 +344,9 @@ def min_sum_scale(x, y): stored_delta.add_(delta) if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): - stored_delta.mul_(0.5 * (beta1 + 1)) + # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. + # this should be configurable. + stored_delta.mul_(0.25 * beta1 + 0.75) eta = 1.0 # scale on subtraction of x3. x3 = compute_prod3(stored_delta) # actually 3rd power of stored_delta divided by max(rows, cols). update_scale = (-eta * (1 - beta1)**2) From 157c664b7666481b426419521bb2896c0a33ab1a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 7 Feb 2026 13:18:40 +0800 Subject: [PATCH 0896/1191] Do not zero out masked frames in sequence_norm. # Conflicts: # egs/librispeech/ASR/zipformer/scaling.py --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index ecd41a236f..fc3f91af97 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -338,9 +338,9 @@ def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tens scales = 1.0 / (x ** 2).mean(dim=(0, 2), keepdim=True).sqrt() else: mask = (~mask).to(torch.float).t().unsqueeze(-1) - x = x * mask + xm = x * mask num_frames = mask.sum(dim=0) - scales = (num_frames / ((x ** 2).mean(dim=2, keepdim=True).sum(dim=0))).sqrt() + scales = (num_frames / ((xm ** 2).mean(dim=2, keepdim=True).sum(dim=0))).sqrt() return x * ((scale * scales) + offset) From 230e5b1615eed031703f649055ff9a172e9059c3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Feb 2026 17:03:10 +0800 Subject: [PATCH 0897/1191] Improve code for cosine LR schedule, making it less critical to estimate batches-per-epoch, and set a minimum factor of 0.1 by default. --- .../ASR/zapformer/combined_scheduler.py | 131 ++++++++++++++++++ egs/librispeech/ASR/zapformer/train.py | 27 ++-- 2 files changed, 150 insertions(+), 8 deletions(-) create mode 100644 egs/librispeech/ASR/zapformer/combined_scheduler.py diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py new file mode 100644 index 0000000000..4be4c16991 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -0,0 +1,131 @@ +import torch +from torch import Tensor +from torch.optim import Optimizer +from typing import List +import math +import logging + +class CombinedLRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch; it estimates the "progress" for you. + """ + def __init__(self, + optimizer: Optimizer, + batches_per_epoch: int, + num_epochs: int, + verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.batches_per_epoch = batches_per_epoch + self.num_epochs = num_epochs # the number of epochs we plan to train for. + + self.epoch = -1 + self.batch = -1 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + # the user might try to override the base_lr, so don't include this in the state. + # previously they were included. + # "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + # Caution: storing batches_per_epoch with the state might not necessarily be what you want, + # it's good for interrupted training runs only as long as you continue to train with the + # same world-size. + "batches_per_epoch": self.batches_per_epoch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def set_batch(self, batch: int): + # set the within-epoch batch index. + self.batch = batch + self._set_lrs() + + def set_epoch(self, epoch: int): + assert epoch > 0 and epoch <= self.num_epochs # Epoch numbers are assumed to be be 1-based indexes. + if epoch == self.epoch + 1 and self.batch > 0: + logging.info(f"Overriding batches_per_epoch from {self.batches_per_epoch} to {self.batch} based on observed batch count.") + self.batches_per_epoch = self.batch + + self.epoch = epoch + self._set_lrs() + + def get_progress(self): + if self.epoch <= 0: + return 0.0 + else: + assert self.epoch <= self.num_epochs + assert self.batches_per_epoch > 0 + whole_epoch_progress = (self.epoch - 1) / self.num_epochs + batch = self.batch + if batch < 0: + partial_epoch_progress = 0 + else: + partial_epoch_progress = min(1.0, batch / self.batches_per_epoch) / self.num_epochs + return whole_epoch_progress + partial_epoch_progress + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.warning( + f"Epoch={self.epoch}, batch={self.batch}, num_epochs={self.num_epochs}, batches_per_epoch={self.batches_per_epoch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class CosineLRScheduler(CombinedLRScheduler): + def __init__(self, + *args, + min_factor: float = 0.1, + **kwargs): + super().__init__(*args, **kwargs) + self.min_factor = min_factor + + def get_lr(self): + progress = self.get_progress() + factor = max(self.min_factor, 0.5 * (1.0 + math.cos(math.pi * progress))) + return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 74303b8445..db6eb41abe 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -76,6 +76,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import TransformedAdam +from combined_scheduler import CombinedLRScheduler, CosineLRScheduler from torch.optim.lr_scheduler import LambdaLR from scaling import ScheduledFloat from subsampling import Conv2dSubsampling @@ -105,8 +106,6 @@ str2bool, ) -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - def get_adjusted_batch_count(params: AttributeDict) -> float: # returns the number of batches we would have used so far if we had used the reference @@ -369,6 +368,15 @@ def get_parser(): help="Number of epochs to train.", ) + parser.add_argument( + "--batches-per-epoch", + type=int, + default=4550, + help="Assumed number of batches per epoch for purposes of setting learning rate; only " + "makes a difference during the first batch, after which an observed value is used.." + ) + + parser.add_argument( "--start-epoch", type=int, @@ -759,7 +767,7 @@ def load_checkpoint_if_available( model: nn.Module, model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, + scheduler: Optional[CombinedLRScheduler] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -825,7 +833,7 @@ def save_checkpoint( model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, + scheduler: Optional[CombinedLRScheduler] = None, sampler: Optional[CutSampler] = None, scaler: Optional[GradScaler] = None, rank: int = 0, @@ -1030,7 +1038,7 @@ def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, + scheduler: CombinedLRScheduler, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -1125,8 +1133,7 @@ def save_bad_model(suffix: str = ""): # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() - scheduler.step() - + scheduler.set_batch(batch_idx) # sets batch-count within the epoch, and sets the LRs. scaler.step(optimizer) scaler.update() optimizer.zero_grad() @@ -1354,7 +1361,10 @@ def lr_lambda(current_step): progress = current_step / total_steps return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - scheduler = LambdaLR(optimizer, lr_lambda) + scheduler = CosineLRScheduler(optimizer, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs, + verbose=True) if checkpoints and "optimizer" in checkpoints: @@ -1491,6 +1501,7 @@ def remove_short_and_long_utt(c: Cut): tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) params.cur_epoch = epoch + scheduler.set_epoch(epoch) train_one_epoch( params=params, From fa031b2b9088f408eb2fc31c7cf5fe1bc5a268b1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Feb 2026 18:09:36 +0800 Subject: [PATCH 0898/1191] Reduce central num layers from 16 to 14. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index db6eb41abe..63855e3265 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -170,7 +170,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,16,8", + default="6,8,14,8", help="Number of zipformer encoder layers per stack, comma separated.", ) From d06def70bf40addbdec9d364c001aa22214d88be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Feb 2026 18:20:31 +0800 Subject: [PATCH 0899/1191] Remove verbose=True --- egs/librispeech/ASR/zapformer/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index db6eb41abe..8fb8131520 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1363,9 +1363,7 @@ def lr_lambda(current_step): scheduler = CosineLRScheduler(optimizer, batches_per_epoch=params.batches_per_epoch, - num_epochs=params.num_epochs, - verbose=True) - + num_epochs=params.num_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From dd86dd7898164e79a28458fc8137bbb83dd52e1b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Feb 2026 19:20:21 +0800 Subject: [PATCH 0900/1191] Make optimizer use 5th, not 3rd, power of matrix for decay. --- egs/librispeech/ASR/zipformer/optim.py | 30 ++++++++++++++------------ 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7e2d5d34bc..18a665aba3 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -152,7 +152,7 @@ def base_step(group, state, grad): return grad / denom -def compute_prod3_inplace(x): # replaces x with x^3 / max(rows, cols), x is interpreted as a batch of matrices. +def compute_prod5_inplace(x): # replaces x with x^3 / max(rows, cols), x is interpreted as a batch of matrices. assert x.ndim >= 3 @@ -164,13 +164,13 @@ def compute_prod3_inplace(x): # replaces x with x^3 / max(rows, cols), x is inte xr = x.reshape(x.shape[0], x.shape[1], -1) else: xr = x.reshape(x.shape[0], -1, x.shape[-1]) - compute_prod3_inplace(xr) + compute_prod5_inplace(xr) if not xr.untyped_storage() is x.untyped_storage(): x[:] = xr.reshape(*x.shape) return if x.shape[1] > x.shape[2]: xr = x.permute(0, 2, 1) - compute_prod3_inplace(xr) + compute_prod5_inplace(xr) if not xr.untyped_storage() is x.untyped_storage(): x[:] = xr.permute(0, 2, 1) return @@ -182,7 +182,7 @@ def compute_prod3_inplace(x): # replaces x with x^3 / max(rows, cols), x is inte for divisor in range(2, 100): if n % divisor == 0 and n // divisor <= max_dim: xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) - compute_prod3_inplace(xr) + compute_prod5_inplace(xr) if not xr.untyped_storage() is x.untyped_storage(): x[:] = xr.reshape(*x.shape) return @@ -191,18 +191,19 @@ def compute_prod3_inplace(x): # replaces x with x^3 / max(rows, cols), x is inte (batch_size, rows, cols) = x.shape # and rows <= cols x2 = torch.matmul(x, x.permute(0, 2, 1)) / max(rows, cols) - x3 = torch.matmul(x2, x) + x4 = torch.matmul(x2, x2) + x5 = torch.matmul(x4, x) - x[:] = x3 + x[:] = x5 -def compute_prod3(x): - # computes matrix-matrix-matrix product of batch of matrices x, with reshaping if necessary; - # first divides x by max(num_rows, num_cols) so its a kind of normalized product. +def compute_prod5(x): + # computes matrix-matrix-matrix-matrix-matrix product of batch of matrices x, with reshaping if necessary; + # first divides x by max(num_rows, num_cols)^2 so its a kind of normalized 5th-product. x = x.clone() - compute_prod3_inplace(x) + compute_prod5_inplace(x) return x @@ -348,10 +349,11 @@ def min_sum_scale(x, y): # this should be configurable. stored_delta.mul_(0.25 * beta1 + 0.75) eta = 1.0 # scale on subtraction of x3. - x3 = compute_prod3(stored_delta) # actually 3rd power of stored_delta divided by max(rows, cols). - update_scale = (-eta * (1 - beta1)**2) - update_scale = min_sum_scale(stored_delta, x3).clamp(min=update_scale) - stored_delta.add_(x3 * update_scale) + update_scale = (eta * (1 - beta1)**3) + x5 = stored_delta * (update_scale ** 0.2) + compute_prod5_inplace(x5) # actually computes 5rd power of its arg divided by max(rows, cols)**2 + alpha = min_sum_scale(stored_delta, x5).clamp(min=-1) + stored_delta.add_(x5 * alpha) else: stored_delta.mul_(beta1) From bfd278f0397e7932e089757450b549dd4bd17e1a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Feb 2026 21:50:29 +0800 Subject: [PATCH 0901/1191] Code refactoring to make self-attention done in one class, should make no difference to results. --- egs/librispeech/ASR/zapformer/train.py | 8 ++ egs/librispeech/ASR/zipformer/zipformer.py | 122 ++++++++++++++------- 2 files changed, 88 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 8fb8131520..e2fab72071 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -230,6 +230,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Position encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + parser.add_argument( "--conv-params", type=str, @@ -675,6 +682,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: encoder_dim=lookup(params, "encoder_dim"), query_head_dim=lookup(params, "query_head_dim"), value_head_dim=lookup(params, "value_head_dim"), + pos_head_dim=lookup(params, "pos_head_dim"), num_heads=lookup(params, "num_heads"), feedforward_multiple=lookup(params, "feedforward_multiple"), conv_params=lookup(params, "conv_params"), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 628e67a117..fd17c03e54 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -74,6 +74,7 @@ class Zipformer2(EncoderInterface): query_head_dim (int or Tuple[int]): dimension of query and key per attention head: per stack, if a tuple.. value_head_dim (int or Tuple[int]): dimension of value in each attention head + pos_head_dim (int or Tuple[int]): dimension of position-embedding in each attention head num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. Must be at least 4. feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules @@ -99,6 +100,7 @@ def __init__( num_encoder_layers: Union[int, Tuple[int]] = 4, query_head_dim: Union[int, Tuple[int]] = 64, value_head_dim: Union[int, Tuple[int]] = 12, + pos_head_dim: Union[int, Tuple[int]] = 4, num_heads: Union[int, Tuple[int]] = 8, feedforward_multiple: Union[int, Tuple[int]] = 4, conv_params: Union[int, Tuple[int]] = 31, @@ -127,6 +129,7 @@ def _to_tuple(x): self.num_encoder_layers = num_encoder_layers self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + self.pos_head_dim = pos_head_dim = _to_tuple(pos_head_dim) self.num_heads = num_heads = _to_tuple(num_heads) feedforward_multiple = _to_tuple(feedforward_multiple) self.conv_params = conv_params = _to_tuple(conv_params) @@ -153,6 +156,7 @@ def _to_tuple(x): num_heads=num_heads[i], query_head_dim=query_head_dim[i], value_head_dim=value_head_dim[i], + pos_head_dim=pos_head_dim[i], feedforward_multiple=feedforward_multiple[i], conv_params=conv_params[i], causal=causal, @@ -508,6 +512,7 @@ def __init__( num_heads: int, query_head_dim: int, value_head_dim: int, + pos_head_dim: int, feedforward_multiple: int, conv_params: int, causal: bool = False, @@ -516,21 +521,19 @@ def __init__( self.embed_dim = embed_dim self.name = None # will be set from training loop - #self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) - self.self_attn_weights = MultiheadAttentionWeights( + self.self_attn = MultiheadRelPosGatedSelfAttention( embed_dim, num_heads=num_heads, query_head_dim=query_head_dim, + value_head_dim=value_head_dim, + pos_head_dim=pos_head_dim, ) - self.self_attn = GatedSelfAttention(embed_dim, num_heads, value_head_dim) - feedforward_dim = embed_dim * feedforward_multiple self.feed_forward1 = FeedforwardModule(embed_dim, feedforward_dim) @@ -573,24 +576,19 @@ def forward( 2. * aux_loss_scale, mask=src_key_padding_mask), None) - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - aux_loss_scale=0.1 * aux_loss_scale, - ) + src_pre_ff1 = src src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.self_attn(src, attn_weights, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + # may try changing src_pre_ff1 to src or vice versa. + src = src + self.self_attn(src_pre_ff1, src, attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale) src = src + self.conv_module(3. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - #residual_scale = limit_param_value(self.residual_scale, min=0.25, max=0.75) residual_scale = 0.25 offset = (src - src_orig) * residual_scale @@ -1027,16 +1025,12 @@ def forward( return x_out.type_as(x) -class MultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with additive relative-position - scores that are kept separate from the regular scores. - - relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. +class MultiheadRelPosGatedSelfAttention(nn.Module): + r""" + Module that computes multi-head attention weights with additive relative-position + scores that are kept separate from the regular scores. The values have gating. + An RMSNorm module is used to pre-normalize the input embedding only as it is + input to the queries and keys, not the values. Args: embed_dim: number of channels at the input to this module, e.g. 256 @@ -1048,7 +1042,8 @@ def __init__( embed_dim: int, num_heads: int, query_head_dim: int, - pos_dim: int = 4, + pos_head_dim: int = 4, + value_head_dim: int = 12, dropout: float = 0.0, ) -> None: super().__init__() @@ -1061,35 +1056,52 @@ def __init__( self.in_norm = RmsNorm() key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5 that has been used in previous forms of attention, # dividing it between the query and key. Note: this module is intended # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( + self.qkp_in_proj = ScaledLinear( embed_dim, in_proj_dim, bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) - #self.rope = RotaryPositionalEmbeddings(query_head_dim) # use default max_seq_len=4096, base=10000 - - self.rel_pos = RelPosScores(num_heads, pos_dim, num_freqs=64) + self.rel_pos = RelPosScores(num_heads, pos_head_dim, num_freqs=64) self.copy_query = Identity() self.copy_pos_query = Identity() + # value and gating in_proj. + self.vg_in_proj = ScaledLinear(embed_dim, 2 * num_heads * value_head_dim, + initial_scale=0.1, bias=True) + + + self.copy_v = nn.Identity() # diagnostics. + self.sigmoid = nn.Sigmoid() + + # out proj for the value times gating. + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.5 + ) + + + def forward( self, - x: Tensor, + x_qkp: Tensor, + x_vg: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, ) -> Tensor: r""" Args: - x: input of shape (seq_len, batch_size, embed_dim) + x_qkp: input of shape (seq_len, batch_size, embed_dim), that is used for the queries, + keys and positions. + x_vg: input of shape (seq_len, batch_size, embed_dim), that is used for the values + and gates. May be the same as x_qk. key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting. attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), @@ -1101,17 +1113,17 @@ def forward( """ query_head_dim = self.query_head_dim num_heads = self.num_heads - x = self.in_norm(x) - x = self.in_proj(x) + x_qkp = self.in_norm(x_qkp) + x_qkp = self.qkp_in_proj(x_qkp) - seq_len, batch_size, _ = x.shape + seq_len, batch_size, _ = x_qkp.shape query_dim = query_head_dim * num_heads # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - p = x[..., 2 * query_dim:] + q = x_qkp[..., 0:query_dim] + k = x_qkp[..., query_dim : 2 * query_dim] + p = x_qkp[..., 2 * query_dim:] q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. @@ -1156,7 +1168,8 @@ def forward( if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. - attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, 0.1 * aux_loss_scale, + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, + 0.01 * aux_loss_scale, # increase? key_padding_mask, self.name) @@ -1172,13 +1185,38 @@ def forward( elif random.random() < 0.001: self._print_attn_entropy(attn_weights) + # note: self.dropout is normally 0.0. attn_weights = nn.functional.dropout( attn_weights, p=self.dropout, training=self.training ) - return attn_weights + v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) + v = v.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + v = self.copy_v(v) + value_head_dim = v.shape[-1] + # now v: (num_heads, batch_size, seq_len, value_head_dim) - def streaming_forward( + # todo: see whether there is benefit in overriding matmul + v = torch.matmul(attn_weights, v) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + v = ( + v.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + + if self.training: + # don't let the sigmoid values get too extreme, limit to -2..2. + g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + v = v * self.sigmoid(g) + v = self.out_proj(v) + return v + + def streaming_forward( # TODO: fix and test, needs to do value and gating stuff. self, x: Tensor, cached_key: Tensor, From 6863963c0e859e4825fd2588b6f717e886400faf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Feb 2026 22:15:33 +0800 Subject: [PATCH 0902/1191] Fix loss scale for penalty in self attention. --- egs/librispeech/ASR/zipformer/zipformer.py | 203 +++++++-------------- 1 file changed, 61 insertions(+), 142 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index fd17c03e54..acae9fbbb9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -583,7 +583,7 @@ def forward( # may try changing src_pre_ff1 to src or vice versa. src = src + self.self_attn(src_pre_ff1, src, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, - aux_loss_scale=aux_loss_scale) + aux_loss_scale=0.1 * aux_loss_scale) src = src + self.conv_module(3. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) @@ -1169,7 +1169,7 @@ def forward( if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, - 0.01 * aux_loss_scale, # increase? + 0.1 * aux_loss_scale, key_padding_mask, self.name) @@ -1216,7 +1216,8 @@ def forward( v = self.out_proj(v) return v - def streaming_forward( # TODO: fix and test, needs to do value and gating stuff. + def streaming_forward_weights( # TODO: fix and test, needs to be combined with value and gating stuff, + # see streaming_forward_vg which I took from the old class. self, x: Tensor, cached_key: Tensor, @@ -1294,6 +1295,63 @@ def streaming_forward( # TODO: fix and test, needs to do value and gating stuff return attn_weights, cached_key + def streaming_forward_vg( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + def _print_attn_entropy(self, attn_weights: Tensor): # attn_weights: (num_heads, batch_size, seq_len, seq_len) (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape @@ -1364,145 +1422,6 @@ def backward( -class GatedSelfAttention(nn.Module): - """ - Self-attention module with sigmoid gating. This one works with already-computed attention - weights, e.g. as computed by MultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = ScaledLinear(embed_dim, 2 * num_heads * value_head_dim, - initial_scale=0.1, bias=True) - - - self.copy_x = nn.Identity() # diagnostics. - self.sigmoid = nn.Sigmoid() - - self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.5 - ) - - f = max(1.0, embed_dim / (num_heads * value_head_dim)) - - - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - aux_loss_scale: float = 0.0, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - src_key_padding_mask: optional Tensor of shape (batch_size, src_seq_len); only - used for the cosine similarity loss, during training. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, 2 * num_heads * value_head_dim) - x, s = x.chunk(2, dim=-1) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - x = self.copy_x(x) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # x: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - - if self.training: - # don't let the sigmoid values get too extreme, limit to -2..2. - s = penalize_abs_values_gt(s, 2, penalty=0.02*aux_loss_scale) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = x * self.sigmoid(s) - x = self.out_proj(x) - - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, ( - cached_val.shape[0], - left_context_len, - ) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] - - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - return x, cached_val - - class FeedforwardModule(nn.Module): """Feedforward module in Zipformer2 model.""" From d52446751712525a73c89b89c8bd2b775b500b3a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Feb 2026 22:42:58 +0800 Subject: [PATCH 0903/1191] Reduce query-head-dim from 128 to 64. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index e2fab72071..9ae6f41e0d 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -219,7 +219,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--query-head-dim", type=str, - default="128", + default="64", help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) From bef9f3eaed4d37fbd48d8e7639d3c3e06d22dfe2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 18 Feb 2026 13:45:48 +0800 Subject: [PATCH 0904/1191] Introduce factor of 0.5 to prevent overshooting; helps initial convergence. --- egs/librispeech/ASR/zipformer/optim.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 18a665aba3..b11ade076d 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -335,7 +335,6 @@ def min_sum_scale(x, y): # the same shape and the shape of alpha is (x.shape[0], 1, 1, ...). assert x.ndim > 1 dims = list(range(1, x.ndim)) - xx = (x ** 2).sum(dim=dims, keepdim=True) yy = (y ** 2).sum(dim=dims, keepdim=True) xy = (y * x).sum(dim=dims, keepdim=True) # sum square of x + alpha y is xx + alpha^2 yy + 2 alpha xy @@ -352,7 +351,12 @@ def min_sum_scale(x, y): update_scale = (eta * (1 - beta1)**3) x5 = stored_delta * (update_scale ** 0.2) compute_prod5_inplace(x5) # actually computes 5rd power of its arg divided by max(rows, cols)**2 - alpha = min_sum_scale(stored_delta, x5).clamp(min=-1) + # the factor of 0.5 says we only want to go, at most, half the way to the point which + # would give us the minimum 'x'. this is to prevent the largest eigs overshooting + # and having the direction change sign, in a situation where we are not dominated by + # the largest singular value; or to prevent the largest singular value from going to + # zero if it does dominate. + alpha = (0.5 * min_sum_scale(stored_delta, x5)).clamp(min=-1) stored_delta.add_(x5 * alpha) else: stored_delta.mul_(beta1) From a96c73c9edea57ed57941a8073b5fb62e56b2a92 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 20 Feb 2026 14:28:41 +0800 Subject: [PATCH 0905/1191] Increase lr from .8e-03 to 1.2e-03 and decrease weight decay from .3 to .2, keeping the product the same; should increase floor param rms. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 54da6e10fb..06c05912ff 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,7 +1358,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.3, + wd=0.2, scale_limits=(1.0, 4.0), ) From a0a6b8786e8f6ae906df0e0b4d86b09de5aad718 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 20 Feb 2026 17:47:09 +0800 Subject: [PATCH 0906/1191] Add factor of 1-linear_decay to 5th-order --- egs/librispeech/ASR/zipformer/optim.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b11ade076d..ebfc9214fc 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -346,7 +346,10 @@ def min_sum_scale(x, y): if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. # this should be configurable. - stored_delta.mul_(0.25 * beta1 + 0.75) + linear_decay = 0.25 + + + stored_delta.mul_(linear_decay * beta1 + (1 - linear_decay)) eta = 1.0 # scale on subtraction of x3. update_scale = (eta * (1 - beta1)**3) x5 = stored_delta * (update_scale ** 0.2) @@ -356,7 +359,7 @@ def min_sum_scale(x, y): # and having the direction change sign, in a situation where we are not dominated by # the largest singular value; or to prevent the largest singular value from going to # zero if it does dominate. - alpha = (0.5 * min_sum_scale(stored_delta, x5)).clamp(min=-1) + alpha = (0.5 * min_sum_scale(stored_delta, x5)).clamp(min=-(1 - linear_decay)) stored_delta.add_(x5 * alpha) else: stored_delta.mul_(beta1) From 22b087fee0148eee0801795579da0c8d095cc1e7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 20 Feb 2026 17:49:34 +0800 Subject: [PATCH 0907/1191] Increase wd from .2 to .25 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 06c05912ff..6fac64a4ff 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,7 +1358,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.2, + wd=0.25, scale_limits=(1.0, 4.0), ) From d718db8aed3b58ca1db280032814684e145b53b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Feb 2026 13:34:14 +0800 Subject: [PATCH 0908/1191] Reduce weight decay from .25 to .2. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6fac64a4ff..06c05912ff 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,7 +1358,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.25, + wd=0.2, scale_limits=(1.0, 4.0), ) From 38a649a533839816b2216c939d1a8cdc3659c535 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Feb 2026 16:54:17 +0800 Subject: [PATCH 0909/1191] Change power from 5 to 3; new, better-motivated formula for determining coefficient. --- egs/librispeech/ASR/zipformer/optim.py | 70 +++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b11ade076d..e5326df09f 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -207,6 +207,60 @@ def compute_prod5(x): return x +def compute_prod3_inplace(x): # replaces x with x^3 / max(rows, cols), x is interpreted as a batch of matrices. + assert x.ndim >= 3 + + + if x.ndim > 3: + # each tensor in the batch has more than two dimensions. + # reshape to be like a batch of matrices. + # note: x.shape[0] is batch dimension. + if x.shape[1] > x.shape[-1]: + xr = x.reshape(x.shape[0], x.shape[1], -1) + else: + xr = x.reshape(x.shape[0], -1, x.shape[-1]) + compute_prod3_inplace(xr) + if not xr.untyped_storage() is x.untyped_storage(): + x[:] = xr.reshape(*x.shape) + return + if x.shape[1] > x.shape[2]: + xr = x.permute(0, 2, 1) + compute_prod3_inplace(xr) + if not xr.untyped_storage() is x.untyped_storage(): + x[:] = xr.permute(0, 2, 1) + return + + # avoid matrix multiplies by any dimensions that are too large. + max_dim = 1024 + if x.shape[1] > max_dim: + n = x.shape[1] + for divisor in range(2, 100): + if n % divisor == 0 and n // divisor <= max_dim: + xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) + compute_prod3_inplace(xr) + if not xr.untyped_storage() is x.untyped_storage(): + x[:] = xr.reshape(*x.shape) + return + # if no divisor worked, just continue. + + (batch_size, rows, cols) = x.shape # and rows <= cols + + x2 = torch.matmul(x, x.permute(0, 2, 1)) / max(rows, cols) + x3 = torch.matmul(x2, x) + + x[:] = x3 + + + + +def compute_prod3(x): + # computes matrix-matrix-matrix-matrix-matrix product of batch of matrices x, with reshaping if necessary; + # first divides x by max(num_rows, num_cols)^2 so its a kind of normalized 3rdproduct. + x = x.clone() + compute_prod3_inplace(x) + return x + + def scale_by(x, beta1): @@ -346,18 +400,20 @@ def min_sum_scale(x, y): if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. # this should be configurable. - stored_delta.mul_(0.25 * beta1 + 0.75) - eta = 1.0 # scale on subtraction of x3. - update_scale = (eta * (1 - beta1)**3) - x5 = stored_delta * (update_scale ** 0.2) - compute_prod5_inplace(x5) # actually computes 5rd power of its arg divided by max(rows, cols)**2 + + linear_decay_scale = 0.25 + + stored_delta.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) + excess_scale = 2.0 # approximately: the amount by which we let the singular values exceed the rms value they would have if the data were i.i.d. + x3 = stored_delta * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times + compute_prod3_inplace(x3) # actually computes 3rd power of its arg divided by max(rows, cols)**2 # the factor of 0.5 says we only want to go, at most, half the way to the point which # would give us the minimum 'x'. this is to prevent the largest eigs overshooting # and having the direction change sign, in a situation where we are not dominated by # the largest singular value; or to prevent the largest singular value from going to # zero if it does dominate. - alpha = (0.5 * min_sum_scale(stored_delta, x5)).clamp(min=-1) - stored_delta.add_(x5 * alpha) + alpha = (0.5 * min_sum_scale(stored_delta, x3)).clamp(min=-1) + stored_delta.add_(x3 * alpha) else: stored_delta.mul_(beta1) From fc40a72d4e2603b108035d21c3860b0ff1cd6784 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Feb 2026 17:06:13 +0800 Subject: [PATCH 0910/1191] Reduce weight decay from .3 to .2. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 54da6e10fb..06c05912ff 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,7 +1358,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.3, + wd=0.2, scale_limits=(1.0, 4.0), ) From 394d45c9f5b79f4bd9fd24451431356519c53691 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 22 Feb 2026 16:19:41 +0800 Subject: [PATCH 0911/1191] Decrease weight decay from .2 to .15 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 06c05912ff..608b10133b 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,7 +1358,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.2, + wd=0.15, scale_limits=(1.0, 4.0), ) From febd8e0cb1aab46d73302c7a35c4161d617699ee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Feb 2026 11:52:41 +0800 Subject: [PATCH 0912/1191] Increase excess_scale from 2 to 2.5. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index e5326df09f..66c8c1c4b8 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -404,7 +404,7 @@ def min_sum_scale(x, y): linear_decay_scale = 0.25 stored_delta.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) - excess_scale = 2.0 # approximately: the amount by which we let the singular values exceed the rms value they would have if the data were i.i.d. + excess_scale = 2.5 x3 = stored_delta * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times compute_prod3_inplace(x3) # actually computes 3rd power of its arg divided by max(rows, cols)**2 # the factor of 0.5 says we only want to go, at most, half the way to the point which From ec260a919a04ee033bf4e2bfdf44db5549c86ce8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Feb 2026 12:01:08 +0800 Subject: [PATCH 0913/1191] Revert wd to .2 which I had accidentally changed --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 608b10133b..06c05912ff 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,7 +1358,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.15, + wd=0.2, scale_limits=(1.0, 4.0), ) From 7ca0aea55bad8d52f695ca97bf3674791dac47a9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Feb 2026 17:36:13 +0800 Subject: [PATCH 0914/1191] Change central number from 14 to 16. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 06c05912ff..d8e9b7e458 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -170,7 +170,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,14,8", + default="6,8,16,8", help="Number of zipformer encoder layers per stack, comma separated.", ) From 58dac55256276418e8045178d45cb1addac082c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Feb 2026 13:26:19 +0800 Subject: [PATCH 0915/1191] Make flooring of LR scheduler be done linearly not by floor --- egs/librispeech/ASR/zapformer/combined_scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 4be4c16991..7eb1bfaa8c 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -127,5 +127,6 @@ def __init__(self, def get_lr(self): progress = self.get_progress() - factor = max(self.min_factor, 0.5 * (1.0 + math.cos(math.pi * progress))) + factor = 0.5 * (1.0 + math.cos(math.pi * progress)) + factor = self.min_factor + (1.0 - self.min_factor) * factor return [x * factor for x in self.base_lrs] From 295b57964ede0545efe27885f77a864fe650b777 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Feb 2026 14:17:11 +0800 Subject: [PATCH 0916/1191] Reduce correlation_limiter limit from .45 to .35, as in 2086 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index acae9fbbb9..8d013e5df1 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -523,7 +523,7 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) - power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + power = 0.35 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn = MultiheadRelPosGatedSelfAttention( From 7fb020376ab07bbc27567d13fe874b81d9c74730 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Feb 2026 17:48:57 +0800 Subject: [PATCH 0917/1191] Introduce ballast into SequenceNorm and implement causal mode --- egs/librispeech/ASR/zipformer/scaling.py | 79 +++++++++++++++------- egs/librispeech/ASR/zipformer/zipformer.py | 17 ++--- 2 files changed, 62 insertions(+), 34 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index fc3f91af97..3bc82ca273 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,15 +333,32 @@ def backward(ctx, x_grad, *args): -def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): - if mask is None: - scales = 1.0 / (x ** 2).mean(dim=(0, 2), keepdim=True).sqrt() +# all arg tensors are scalars +def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor, causal: bool, mask: Optional[Tensor]): + stats = (x ** 2).mean(dim=2, keepdim=True) + # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so + # make absolutely sure using abs(). + ballast_frames = 100.0 * ballast_frames.abs() + ballast = ballast_frames * (ballast_rms ** 2) + T = x.shape[0] # time + + if causal: + # no need for mask in causal mode. + stats = stats.cumsum(dim=0) + ballast + lengths = ballast_frames + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] else: - mask = (~mask).to(torch.float).t().unsqueeze(-1) - xm = x * mask - num_frames = mask.sum(dim=0) - scales = (num_frames / ((xm ** 2).mean(dim=2, keepdim=True).sum(dim=0))).sqrt() + if mask is None: + # no need for mask in causal mode. + stats = stats.sum(dim=0) + ballast + lengths = ballast_frames + T + else: + mask = (~mask).to(torch.float).t().unsqueeze(-1) + stats = stats * mask + stats = stats.sum(dim=0) + ballast + lengths = ballast_frames + mask.sum(dim=0) + scales = (lengths / stats).sqrt() # (T, batch_size, 1) if causal else (batch_size 1) + assert scales.shape == (T, x.shape[1], 1) if causal else (x.shape[1], 1) return x * ((scale * scales) + offset) @@ -352,29 +369,32 @@ def forward( x: Tensor, offset: Tensor, scale: Tensor, + ballast_rms: Tensor, + ballast_frames: Tensor, + causal: bool, mask: Optional[Tensor], ) -> Tensor: - ctx.save_for_backward(x, offset, scale) + ctx.save_for_backward(x, offset, scale, ballast_rms, ballast_frames) + ctx.causal = causal ctx.mask = mask - return _sequence_norm(x, offset, scale, mask) + return _sequence_norm(x, offset, scale, ballast_rms, ballast_frames, causal, mask) @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - x, offset, scale = ctx.saved_tensors - mask = ctx.mask + x, offset, scale, ballast_rms, ballast_frames = ctx.saved_tensors - with torch.amp.autocast('cuda', enabled=False): - x, offset, scale = x.to(torch.float32), offset.to(torch.float32), scale.to(torch.float32) - x, offset, scale = x.detach(), offset.detach(), scale.detach() - x.requires_grad = True - scale.requires_grad = True - offset.requires_grad = True + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float32).detach().requires_grad_() + offset = offset.to(torch.float32).detach().requires_grad_() + scale = scale.to(torch.float32).detach().requires_grad_() + ballast_rms = ballast_rms.to(torch.float32).detach().requires_grad_() + ballast_frames = ballast_frames.to(torch.float32).detach().requires_grad_() with torch.enable_grad(): - ans = _sequence_norm(x, offset, scale, ctx.mask) + ans = _sequence_norm(x, offset, scale, ballast_rms, ballast_frames, ctx.causal, ctx.mask) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): @@ -382,7 +402,7 @@ def c(x): # in autocast mode. return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(offset.grad), c(scale.grad), None + return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad), None, None class SequenceNorm(torch.nn.Module): @@ -395,21 +415,26 @@ class SequenceNorm(torch.nn.Module): """ def __init__( self, + causal: bool, ) -> None: super(SequenceNorm, self).__init__() self.scale = nn.Parameter(torch.tensor(0.5)) self.offset = nn.Parameter(torch.tensor(0.0001)) - - + # ballast_mean: assumed rms value of ballast frames used to pad stats + self.ballast_rms = nn.Parameter(torch.tensor(0.1)) + # ballast_frames: number of ballast frames, in hundreds (will be multiplied by 100) + self.ballast_frames = nn.Parameter(torch.tensor(0.05)) # number of ballast frames, will be multiplied by 100 + self.causal = causal self.name = None def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # x: (seq, batch, channel) # mask: bool, (batch_size, seq_len) + # Note: mask is ignored in causal mode. if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _sequence_norm(x, self.offset, self.scale, mask) + return _sequence_norm(x, self.offset, self.scale, self.ballast_rms, self.ballast_frames, self.causal, mask) scale = limit_param_value( self.scale, min=0.05, max=2.0, training=self.training) @@ -417,14 +442,20 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: offset = limit_param_value( self.offset, min=0.0, max=10.0, training=self.training) + ballast_rms = limit_param_value( + self.ballast_rms, min=0.0, max=10.0, training=self.training) + + ballast_frames = limit_param_value( + self.ballast_frames, min=0.0, max=5.0, training=self.training) # max of 5.0 would be 500 frames + ans = SequenceNormFunction.apply( - x, offset, scale, mask, + x, offset, scale, ballast_rms, ballast_frames, self.causal, mask, ) if random.random() < 0.002: x_rms = (x ** 2).mean().sqrt() ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}") + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}, ballast_rms={self.ballast_rms.item()}, ballast_frames*100={100*self.ballast_frames.item()}") return ans diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8d013e5df1..e3f7dfd2b6 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -26,13 +26,16 @@ import torch from encoder_interface import EncoderInterface from scaling import ( + ActivationDropoutAndLinear, + ChunkCausalDepthwiseConv1d, + CosineSimilarityLoss, + CorrelationLimiter, Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. OrthogonalLinear, + RmsNorm, + SequenceNorm, SimpleOrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - ActivationDropoutAndLinear, - ChunkCausalDepthwiseConv1d, - CosineSimilarityLoss, ScheduledFloat, FloatLike, SwashR, @@ -43,12 +46,6 @@ ScaleLimiter, with_loss, ) -try: - from scaling import CorrelationLimiter - from scaling import SequenceNorm - from scaling import RmsNorm -except: - pass from torch import Tensor, nn @@ -541,7 +538,7 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal) - self.norm = SequenceNorm() + self.norm = SequenceNorm(causal=causal) def forward( From e0215b003aeaa85e4bc598f71efb45aeb3a1b909 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 1 Mar 2026 16:00:20 +0800 Subject: [PATCH 0918/1191] Reformulate weight decay to have a square, and change wd from 0.2 to 12.5 --- .../ASR/zapformer/combined_scheduler.py | 5 +++-- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/optim.py | 14 +++++++------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 4be4c16991..81e1ae10a3 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -120,12 +120,13 @@ def print_lr(self, is_verbose, group, lr): class CosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, - min_factor: float = 0.1, + min_factor: float = 0.2, **kwargs): super().__init__(*args, **kwargs) self.min_factor = min_factor def get_lr(self): progress = self.get_progress() - factor = max(self.min_factor, 0.5 * (1.0 + math.cos(math.pi * progress))) + factor = 0.5 * (1.0 + math.cos(math.pi * progress)) + factor = self.min_factor + (1. - self.min_factor) * factor return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 06c05912ff..9ce3f9a5f0 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,7 +1358,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=0.2, + wd=12.5, scale_limits=(1.0, 4.0), ) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 66c8c1c4b8..49b1d7931f 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -458,7 +458,7 @@ def scaling_step(group, param, state, grad): scale_ratio = scale / old_scale - delta_scale = (scale_ratio * (1 - lr * wd)) - 1 + delta_scale = (scale_ratio * (1 - (lr * wd) ** 2)) - 1 return param * delta_scale + scale * delta @@ -518,11 +518,11 @@ class TransformedAdam(BatchedOptimizer): def __init__( self, params, - lr=3e-02, + lr=1e-03, beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, - wd=0.15, + wd=10, eps=1.0e-08, scale_limits=(0.5, 2.0), ): @@ -926,11 +926,11 @@ class SimpleTransformedAdam(Optimizer): def __init__( self, params, - lr=3e-02, + lr=1e-03, beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) beta2=0.98, - wd=0.15, + wd=10, eps=1.0e-08, scale_limits=(0.5, 2.0), ): @@ -1036,9 +1036,9 @@ def _test_transformed_adam(hidden_dim: int): lr = 0.001 if test == 0: - optim = TransformedAdam(m.named_parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.99) + optim = TransformedAdam(m.named_parameters(), lr=lr, wd=12, eps=1.0e-20, beta1=0.99) elif test == 1: - optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=0.15, eps=1.0e-20, beta1=0.99) + optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=12, eps=1.0e-20, beta1=0.99) num_epochs = 180 From 6372768d3bd5589374607ddac33c7299afa8d020 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 1 Mar 2026 14:00:42 +0800 Subject: [PATCH 0919/1191] Set lr_scale=0.66 in conv_module.depthwise_conv --- egs/librispeech/ASR/zipformer/zipformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index acae9fbbb9..6323ea373d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1726,6 +1726,7 @@ def __init__( self.activation2 = Identity() # for diagnostics self.depthwise_conv = FftConv(bottleneck_dim, kernel_size) + self.depthwise_conv.lr_scale = 0.66 self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, From 64663f8f469eb80f7bffdf02ab4d4f7b19ee5ec1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 20 Feb 2026 13:56:35 +0800 Subject: [PATCH 0920/1191] Increase max_rms in offset_scale_limiter from 0.5 to 1.0. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6323ea373d..8b7e7c6012 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -521,7 +521,7 @@ def __init__( self.embed_dim = embed_dim self.name = None # will be set from training loop - self.offset_scale_limiter = ScaleLimiter(max_rms=0.5) + self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) From 7b5621ab4b073759200770043d5416ec6884f2bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Mar 2026 09:31:15 +0800 Subject: [PATCH 0921/1191] Change cosine to linear LR scheduler --- .../ASR/zapformer/combined_scheduler.py | 15 +++++++++++++++ egs/librispeech/ASR/zapformer/train.py | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 81e1ae10a3..01ba074bb9 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -130,3 +130,18 @@ def get_lr(self): factor = 0.5 * (1.0 + math.cos(math.pi * progress)) factor = self.min_factor + (1. - self.min_factor) * factor return [x * factor for x in self.base_lrs] + + +class LinearLRScheduler(CombinedLRScheduler): + def __init__(self, + *args, + min_factor: float = 0.05, + **kwargs): + super().__init__(*args, **kwargs) + self.min_factor = min_factor + + def get_lr(self): + progress = self.get_progress() + factor = 1.0 - progress + factor = self.min_factor + (1. - self.min_factor) * factor + return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 9ce3f9a5f0..caafb31242 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -76,7 +76,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import TransformedAdam -from combined_scheduler import CombinedLRScheduler, CosineLRScheduler +from combined_scheduler import CombinedLRScheduler, CosineLRScheduler, LinearLRScheduler from torch.optim.lr_scheduler import LambdaLR from scaling import ScheduledFloat from subsampling import Conv2dSubsampling @@ -1369,7 +1369,7 @@ def lr_lambda(current_step): progress = current_step / total_steps return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - scheduler = CosineLRScheduler(optimizer, + scheduler = LinearLRScheduler(optimizer, batches_per_epoch=params.batches_per_epoch, num_epochs=params.num_epochs) From 6331dd1e151373cba50d7926434f1bed0f6a5911 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 3 Mar 2026 10:32:28 +0800 Subject: [PATCH 0922/1191] Change for m ultiple jobs running --- egs/librispeech/ASR/zapformer/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index caafb31242..f59e1160de 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -76,7 +76,11 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import TransformedAdam -from combined_scheduler import CombinedLRScheduler, CosineLRScheduler, LinearLRScheduler +from combined_scheduler import CombinedLRScheduler, CosineLRScheduler +try: + from combined_scheduler import LinearLRScheduler +except: + pass from torch.optim.lr_scheduler import LambdaLR from scaling import ScheduledFloat from subsampling import Conv2dSubsampling From 3e8c9fb5f84f857c826668d53988a19b5133262e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Mar 2026 11:21:03 +0800 Subject: [PATCH 0923/1191] Change LR schedule to stay constant for 0.2 of the duration. --- egs/librispeech/ASR/zapformer/combined_scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 01ba074bb9..3eddd832e4 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -135,13 +135,18 @@ def get_lr(self): class LinearLRScheduler(CombinedLRScheduler): def __init__(self, *args, + const_fraction: float = 0.2, # fraction of schedule for which we stay at 1.0 min_factor: float = 0.05, **kwargs): super().__init__(*args, **kwargs) + self.const_fraction = const_fraction self.min_factor = min_factor def get_lr(self): progress = self.get_progress() - factor = 1.0 - progress + # initially: factor is constant at 1.0 until progress==self.const_fraction, then decays to 0 + # at the end. + factor = (1.0 if progress <= self.const_fraction else (1.0 - progress) / (1. - self.const_fraction)) + # then, modify for self.min_factor factor = self.min_factor + (1. - self.min_factor) * factor return [x * factor for x in self.base_lrs] From 57a28a398a4de480d333834879321625eecdb25b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Mar 2026 14:12:08 +0800 Subject: [PATCH 0924/1191] Get muon test working; slight refactoring of the sqrt scale. --- egs/librispeech/ASR/zipformer/muon.py | 22 +++++++++++---------- egs/librispeech/ASR/zipformer/optim.py | 27 +++++++++++++++++++++----- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 7d36475178..9e0189e432 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -179,13 +179,13 @@ def __init__( # Do not use Muon for parameters in adamw_params self.state[p]["use_muon"] = False - def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float: - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr + #def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float: + # A, B = param_shape[:2] + # # We adjust the learning rate and weight decay based on the size of the parameter matrix + # # as describted in the paper + # adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + # adjusted_lr = lr * adjusted_ratio + # return adjusted_lr def step(self, closure=None): """Perform a single optimization step. @@ -236,9 +236,12 @@ def step(self, closure=None): u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - # scale update - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # scale so u should have unit RMS; we remove this factor from + # adjust_lr_for_muon() and simply use the factor of 0.2 below + u = u * (max(p.shape[0], p.shape[1]) ** 0.5) + # multipliying by 0.2 is what's left of adjust_lr_for_muon(0 + adjusted_lr = 0.2 * lr old_scale = scale.clone() @@ -250,7 +253,6 @@ def step(self, closure=None): # apply changes in scale, together with conventional decay. p.data.mul_(scale_ratio * (1 - lr * wd)) - # apply update p.data.add_(u * scale, alpha=-adjusted_lr) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 49b1d7931f..25486ff7dd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1155,11 +1155,29 @@ def _test_muon(hidden_dim: int): adamw_params=[m for m in m.parameters() if m.ndim != 2], lr=1e-03) - scheduler = Sched3(optim, lr_batches=100, power=0.9, warmup_start=0.1, verbose=False) + + num_epochs = 180 + warmup_steps = 0 + # hardcode batches per epoch for now. + total_steps = num_epochs + warmup_start = 0.5 + def lr_lambda(current_step): + if current_step < warmup_steps: + # Linear warm-up + return warmup_start + (1.0 - warmup_start) * current_step / warmup_steps + else: + # Cosine annealing + progress = (current_step - warmup_steps) / (total_steps - warmup_steps) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + + + scheduler = LambdaLR(optim, lr_lambda) start = timeit.default_timer() avg_loss = 0.0 - for epoch in range(180): + for epoch in range(num_epochs): + scheduler.step() + # if epoch == 100 and test in [2,3]: # optim.reset_speedup() # check it doesn't crash. @@ -1170,7 +1188,6 @@ def _test_muon(hidden_dim: int): # diagnostic = diagnostics.attach_diagnostics(m, opts) for n, (x, y) in enumerate(train_pairs): - scheduler.step_batch() y_out = m(x) loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: @@ -1223,5 +1240,5 @@ def _test_muon(hidden_dim: int): else: hidden_dim = 200 - #_test_muon(hidden_dim) - _test_transformed_adam(hidden_dim) + _test_muon(hidden_dim) + #_test_transformed_adam(hidden_dim) From 7e8f8b1adcbae7e69888865c27d481bf97223e46 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Mar 2026 14:36:39 +0800 Subject: [PATCH 0925/1191] Implement row and column scales with moving buffer, scale both before and after newton-schulz --- egs/librispeech/ASR/zipformer/muon.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 9e0189e432..e54d753960 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -218,9 +218,13 @@ def step(self, closure=None): # calc update state = self.state[p] if "momentum_buffer" not in state: + state["delta2_buffer0"] = torch.ones(g.shape[0], device=g.device, dtype=g.dtype) + state["delta2_buffer1"] = torch.ones(g.shape[1], device=g.device, dtype=g.dtype) state["momentum_buffer"] = torch.zeros_like(g) state["scale"] = torch.tensor(1.0, device=g.device) # scalar state["scale_grad_buffer"] = torch.tensor(0.0, device=g.device) # scalar + delta2_buffer0 = state["delta2_buffer0"] + delta2_buffer1 = state["delta2_buffer1"] buf = state["momentum_buffer"] scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] @@ -234,13 +238,26 @@ def step(self, closure=None): else: g = buf + eps = 1.0e-08 + + # we'll scale both before and after the newton-schulz + row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt().unsqueeze(-1) * (delta2_buffer1 + eps).sqrt()) + + g = g * row_col_scale + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) # scale so u should have unit RMS; we remove this factor from # adjust_lr_for_muon() and simply use the factor of 0.2 below u = u * (max(p.shape[0], p.shape[1]) ** 0.5) - # multipliying by 0.2 is what's left of adjust_lr_for_muon(0 + beta2 = 0.98 + delta2_buffer0.mul_(beta2).add_((u ** 2).mean(dim=1), alpha=(1 - beta2)) + delta2_buffer1.mul_(beta2).add_((u ** 2).mean(dim=0), alpha=(1 - beta2)) + + u = u * row_col_scale + + # multiplying by 0.2 is what's left of adjust_lr_for_muon() adjusted_lr = 0.2 * lr old_scale = scale.clone() From bfbe38e77f95afaa727f372788b77a72278ea490 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Mar 2026 15:54:40 +0800 Subject: [PATCH 0926/1191] Changes to LR schedule and weight decay formula (make it squared), this does not help in the test setup though. --- egs/librispeech/ASR/zipformer/muon.py | 12 ++++++------ egs/librispeech/ASR/zipformer/optim.py | 19 ++++++++----------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index e54d753960..ebaa18384e 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -139,13 +139,13 @@ class Muon(torch.optim.Optimizer): adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. - adamw_wd: The weight decay for the internal AdamW. + wd: weight decay for muon and adamw, this is a squared type of weight decay, requires a large value + which dimensionally is like an inverse of a parameter rms """ - def __init__( self, lr=1e-3, - wd=0.1, + wd=10.0, # weight decay is a squared type, needs larger wd value, muon_params=None, momentum=0.95, nesterov=True, @@ -153,7 +153,7 @@ def __init__( adamw_params=None, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, - scale_limits=(0.5, 2.0), + scale_limits=(1.0, 4.0), ): defaults = dict( lr=lr, @@ -268,7 +268,7 @@ def step(self, closure=None): scale_ratio = scale / old_scale # apply changes in scale, together with conventional decay. - p.data.mul_(scale_ratio * (1 - lr * wd)) + p.data.mul_(scale_ratio * (1 - (lr * wd) ** 2)) # apply update p.data.add_(u * scale, alpha=-adjusted_lr) @@ -301,7 +301,7 @@ def step(self, closure=None): bias_correction1 = 1 - beta1**step bias_correction2 = 1 - beta2**step scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - lr * weight_decay) + p.data.mul_(1 - (lr * weight_decay) ** 2) p.data.add_(g, alpha=-lr / scale) return loss diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 25486ff7dd..e90f38a818 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1153,23 +1153,20 @@ def _test_muon(hidden_dim: int): optim = Muon(muon_params=[m for m in m.parameters() if m.ndim == 2], adamw_params=[m for m in m.parameters() if m.ndim != 2], - lr=1e-03) - + lr=0.5e-03, + wd=12.0) num_epochs = 180 - warmup_steps = 0 # hardcode batches per epoch for now. total_steps = num_epochs - warmup_start = 0.5 + constant_fraction = 0.25 + def lr_lambda(current_step): - if current_step < warmup_steps: - # Linear warm-up - return warmup_start + (1.0 - warmup_start) * current_step / warmup_steps + progress = current_step / total_steps + if progress < constant_fraction: + return 1.0 else: - # Cosine annealing - progress = (current_step - warmup_steps) / (total_steps - warmup_steps) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - + return (1.0 - progress) / (1.0 - constant_fraction) scheduler = LambdaLR(optim, lr_lambda) From 0377c3e08b29c3ff27a710002c9d027f1d3c9fd1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Mar 2026 16:26:18 +0800 Subject: [PATCH 0927/1191] Remove the older scaling method --- egs/librispeech/ASR/zipformer/muon.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index ebaa18384e..cde46e475a 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -94,7 +94,6 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" X = X.T # Ensure spectral 4-norm is at most 1 eps = 1e-7 - X = X / ((X ** 2).sum(dim=0) + eps**2).sqrt() # normalize columns X = X / (norm4(X) + eps) # Perform the NS iterations for _ in range(steps): @@ -102,10 +101,6 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X - # now x: (rows, cols) with rows <= cols - scale = (X.shape[0] / X.shape[1]) ** 0.5 # adjust so overall scale is not changed by next line. - X = X * (scale / ((X ** 2).sum(dim=0) + eps**2).sqrt()) - if G.size(0) > G.size(1): X = X.T From 6a301b02f8574e5e95c3f4149f93ce1968d542ad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Mar 2026 16:41:04 +0800 Subject: [PATCH 0928/1191] Bug fixes and configuration changes. --- egs/librispeech/ASR/zipformer/muon.py | 56 ++++++++++++++------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index cde46e475a..0331cb1f3c 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -75,7 +75,7 @@ def prod(l): if diffs[i-1] == min_diff: return prod(shape[:i]), prod(shape[i:]) -def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor": +def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int, state: dict) -> "torch.Tensor": """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. @@ -92,8 +92,21 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" X = G.bfloat16() if G.size(0) > G.size(1): X = X.T - # Ensure spectral 4-norm is at most 1 + + if "delta2_buffer0" not in state: + state["delta2_buffer0"] = torch.ones(X.shape[0], device=X.device, dtype=X.dtype) + state["delta2_buffer1"] = torch.ones(X.shape[1], device=X.device, dtype=X.dtype) + delta2_buffer0 = state["delta2_buffer0"] + delta2_buffer1 = state["delta2_buffer1"] + + eps = 1e-7 + + # we'll scale both before and after the newton-schulz + row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt().unsqueeze(-1) * (delta2_buffer1 + eps).sqrt()) + X = X * row_col_scale + + # Ensure spectral 4-norm is at most 1 X = X / (norm4(X) + eps) # Perform the NS iterations for _ in range(steps): @@ -101,12 +114,17 @@ def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor" B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X - if G.size(0) > G.size(1): - X = X.T + # the following scales so if the newton-schulz was exact, the elements of X would have unit RMS. + X = X * (max(X.shape[0], X.shape[1]) ** 0.5) + X2 = X ** 2 + beta = 0.98 + delta2_buffer0.mul_(beta).add_(X2.mean(dim=1), alpha=(1 - beta)) + delta2_buffer1.mul_(beta).add_(X2.mean(dim=0), alpha=(1 - beta)) + X = X * row_col_scale - if random.random() < 0.0001: - logging.info(f"zeropower_via_newtonschulz5: shape={X.shape}, singular-value-rms={X.norm()/(min(X.shape[0],X.shape[1])**0.5)}") + if G.size(0) > G.size(1): + X = X.T return X.reshape(orig_shape) @@ -148,7 +166,7 @@ def __init__( adamw_params=None, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, - scale_limits=(1.0, 4.0), + scale_limits=(0.5, 4.0), ): defaults = dict( lr=lr, @@ -213,13 +231,9 @@ def step(self, closure=None): # calc update state = self.state[p] if "momentum_buffer" not in state: - state["delta2_buffer0"] = torch.ones(g.shape[0], device=g.device, dtype=g.dtype) - state["delta2_buffer1"] = torch.ones(g.shape[1], device=g.device, dtype=g.dtype) state["momentum_buffer"] = torch.zeros_like(g) state["scale"] = torch.tensor(1.0, device=g.device) # scalar state["scale_grad_buffer"] = torch.tensor(0.0, device=g.device) # scalar - delta2_buffer0 = state["delta2_buffer0"] - delta2_buffer1 = state["delta2_buffer1"] buf = state["momentum_buffer"] scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] @@ -235,24 +249,12 @@ def step(self, closure=None): eps = 1.0e-08 - # we'll scale both before and after the newton-schulz - row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt().unsqueeze(-1) * (delta2_buffer1 + eps).sqrt()) - - g = g * row_col_scale - - u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - - # scale so u should have unit RMS; we remove this factor from - # adjust_lr_for_muon() and simply use the factor of 0.2 below - u = u * (max(p.shape[0], p.shape[1]) ** 0.5) - - beta2 = 0.98 - delta2_buffer0.mul_(beta2).add_((u ** 2).mean(dim=1), alpha=(1 - beta2)) - delta2_buffer1.mul_(beta2).add_((u ** 2).mean(dim=0), alpha=(1 - beta2)) - u = u * row_col_scale + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], state=state) - # multiplying by 0.2 is what's left of adjust_lr_for_muon() + # multiplying by 0.2 is what's left of adjust_lr_for_muon(), + # we used the factor of (max(p.shape[0], p.shape[1]) ** 0.5) inside + # zeropower_via_newtonschulz5. adjusted_lr = 0.2 * lr old_scale = scale.clone() From a3f7e7bb6704bf1ebf2980fa392cb089d010d6b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Mar 2026 17:01:40 +0800 Subject: [PATCH 0929/1191] Change train.py to use muon, now supporting parameter groups. --- egs/librispeech/ASR/zapformer/train.py | 16 +++++---------- egs/librispeech/ASR/zipformer/muon.py | 28 ++++---------------------- egs/librispeech/ASR/zipformer/optim.py | 3 +-- 3 files changed, 10 insertions(+), 37 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index f59e1160de..2651d92135 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -76,6 +76,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import TransformedAdam +from muon import Muon from combined_scheduler import CombinedLRScheduler, CosineLRScheduler try: from combined_scheduler import LinearLRScheduler @@ -1359,19 +1360,12 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = TransformedAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + optimizer = Muon( + get_parameter_groups_with_lrs(model, lr=params.base_lr), lr=params.base_lr, wd=12.5, - scale_limits=(1.0, 4.0), - ) - - # hardcode batches per epoch for now. - total_steps = 4550 * params.num_epochs - def lr_lambda(current_step): - # Cosine annealing - progress = current_step / total_steps - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + scale_limits=(0.5, 4.0), + ) scheduler = LinearLRScheduler(optimizer, batches_per_epoch=params.batches_per_epoch, diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py index 0331cb1f3c..df69d1c166 100644 --- a/egs/librispeech/ASR/zipformer/muon.py +++ b/egs/librispeech/ASR/zipformer/muon.py @@ -157,13 +157,12 @@ class Muon(torch.optim.Optimizer): """ def __init__( self, + params, lr=1e-3, wd=10.0, # weight decay is a squared type, needs larger wd value, - muon_params=None, momentum=0.95, nesterov=True, ns_steps=5, - adamw_params=None, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, scale_limits=(0.5, 4.0), @@ -178,27 +177,7 @@ def __init__( adamw_eps=adamw_eps, scale_limits=scale_limits, ) - - params = list(muon_params) - adamw_params = list(adamw_params) if adamw_params is not None else [] - params.extend(adamw_params) super().__init__(params, defaults) - # Sort parameters into those for which we will use Muon, and those for which we will not - for p in muon_params: - # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer - assert p.ndim > 1, p.ndim - self.state[p]["use_muon"] = True - for p in adamw_params: - # Do not use Muon for parameters in adamw_params - self.state[p]["use_muon"] = False - - #def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float: - # A, B = param_shape[:2] - # # We adjust the learning rate and weight decay based on the size of the parameter matrix - # # as describted in the paper - # adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - # adjusted_lr = lr * adjusted_ratio - # return adjusted_lr def step(self, closure=None): """Perform a single optimization step. @@ -214,7 +193,7 @@ def step(self, closure=None): for group in self.param_groups: # Muon loop - params = [p for p in group["params"] if self.state[p]["use_muon"]] + params = [p for p in group["params"] if p.numel() != max(p.shape, default=1)] lr = group["lr"] wd = group["wd"] momentum = group["momentum"] @@ -271,7 +250,8 @@ def step(self, closure=None): p.data.add_(u * scale, alpha=-adjusted_lr) # Adam backup - params = [p for p in group["params"] if not self.state[p]["use_muon"]] + params = [p for p in group["params"] if p.numel() == max(p.shape, default=1)] + lr = group["lr"] beta1, beta2 = group["adamw_betas"] eps = group["adamw_eps"] diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index e90f38a818..25c34a6b64 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1151,8 +1151,7 @@ def _test_muon(hidden_dim: int): for _ in range(20) ] - optim = Muon(muon_params=[m for m in m.parameters() if m.ndim == 2], - adamw_params=[m for m in m.parameters() if m.ndim != 2], + optim = Muon(m.parameters(), lr=0.5e-03, wd=12.0) From af6d43ca49444f04463b0dc9203b6511a0421951 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Mar 2026 10:24:17 +0800 Subject: [PATCH 0930/1191] Decrease wd from 12.5 to 10.0 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2651d92135..b12986e538 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1363,7 +1363,7 @@ def run(rank, world_size, args): optimizer = Muon( get_parameter_groups_with_lrs(model, lr=params.base_lr), lr=params.base_lr, - wd=12.5, + wd=10.0, scale_limits=(0.5, 4.0), ) From 7dcf06a30228513c77ac7e6a6059875854634991 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Mar 2026 13:44:58 +0800 Subject: [PATCH 0931/1191] Change base_step of optim.py to only normalize grad globally. --- egs/librispeech/ASR/zipformer/optim.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 25c34a6b64..4db87dcb2c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -129,6 +129,8 @@ def batched_params(self, param_group, group_params_names): def base_step(group, state, grad): # computes basic Adam normalized-grad using beta2 (dividing by gradient stddev) only. no momentum yet. + # this normalied-grad is normalized only at the whole tensor level for now. + beta2 = group["beta2"] eps = group["eps"] # p shape: (batch_size,) or (batch_size, 1, [1,..]) @@ -136,10 +138,17 @@ def base_step(group, state, grad): exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) or (batch_size, 1, [1,..]) except KeyError: assert state["step"] < 2 - exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + batch_size = grad.shape[0] + stats_shape = [batch_size] + [1] * (len(grad.shape) - 1) + exp_avg_sq = torch.zeros(*stats_shape, device=grad.device, dtype=torch.float) state["exp_avg_sq"] = exp_avg_sq - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + mean_dims = list(range(1, grad.ndim)) + grad2 = (grad ** 2) + if len(mean_dims) > 0: + grad2 = grad2.mean(dim=mean_dims, keepdim=True) + exp_avg_sq.mul_(beta2).add_(grad2, alpha=1 - beta2) # bias_correction2 is like in Adam. # slower update at the start will help stability anyway. @@ -1236,5 +1245,5 @@ def lr_lambda(current_step): else: hidden_dim = 200 - _test_muon(hidden_dim) - #_test_transformed_adam(hidden_dim) + #_test_muon(hidden_dim) + _test_transformed_adam(hidden_dim) From b2989394b1560a2f6a7fd2720503818ee534f96f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Mar 2026 13:46:38 +0800 Subject: [PATCH 0932/1191] Take train.py from branch 2130, reverting to use TransformedAdam with newer scheduler. --- .../ASR/zapformer/combined_scheduler.py | 4 ++-- egs/librispeech/ASR/zapformer/train.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 3eddd832e4..58e120728e 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -136,7 +136,7 @@ class LinearLRScheduler(CombinedLRScheduler): def __init__(self, *args, const_fraction: float = 0.2, # fraction of schedule for which we stay at 1.0 - min_factor: float = 0.05, + min_factor: float = 0.1, **kwargs): super().__init__(*args, **kwargs) self.const_fraction = const_fraction @@ -148,5 +148,5 @@ def get_lr(self): # at the end. factor = (1.0 if progress <= self.const_fraction else (1.0 - progress) / (1. - self.const_fraction)) # then, modify for self.min_factor - factor = self.min_factor + (1. - self.min_factor) * factor + factor = max(factor, self.min_factor) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b12986e538..f59e1160de 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import TransformedAdam -from muon import Muon from combined_scheduler import CombinedLRScheduler, CosineLRScheduler try: from combined_scheduler import LinearLRScheduler @@ -1360,12 +1359,19 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = Muon( - get_parameter_groups_with_lrs(model, lr=params.base_lr), + optimizer = TransformedAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=10.0, - scale_limits=(0.5, 4.0), - ) + wd=12.5, + scale_limits=(1.0, 4.0), + ) + + # hardcode batches per epoch for now. + total_steps = 4550 * params.num_epochs + def lr_lambda(current_step): + # Cosine annealing + progress = current_step / total_steps + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) scheduler = LinearLRScheduler(optimizer, batches_per_epoch=params.batches_per_epoch, From 76983d0d4ec17cd17b0171ed8b30c29ac92167ed Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Mar 2026 14:35:22 +0800 Subject: [PATCH 0933/1191] Refactoring to ensure things are matrix shaped outside compute_prod3 --- egs/librispeech/ASR/zipformer/optim.py | 38 ++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 4db87dcb2c..de7ab08782 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -371,6 +371,27 @@ def scale_by(x, beta1): logging.info(f"shape={x.shape}, beta1={beta1}, alpha={alpha}, alpha/(((1-beta1)**2)/dim)={alpha/(((1-beta1)**2)/max(rows,cols))}, post_scale={post_scale}, dot_prod_ratio={dot_prod2/dot_prod1}") +def get_matrix_shape(shape): + shape = list(shape) + batch_size = shape[0] # batch size is 1st element of shape + shape = shape[1:] + def prod(l): + ans = l[0] + for n in l[1:]: + ans = ans * n + return ans + n = len(shape) + diffs = [ ] + for i in range(1, n): + prod1 = prod(shape[:i]) + prod2 = prod(shape[i:]) + diff = abs(prod1 - prod2) + diffs.append(diff) + min_diff = min(diffs) + for i in range(1, n): + if diffs[i-1] == min_diff: + return batch_size, prod(shape[:i]), prod(shape[i:]) + def momentum_step(group, state, grad): delta = base_step(group, state, grad) @@ -407,14 +428,18 @@ def min_sum_scale(x, y): stored_delta.add_(delta) if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): + + delta_reshaped = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) + # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. # this should be configurable. linear_decay_scale = 0.25 - stored_delta.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) + delta_reshaped.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) excess_scale = 2.5 - x3 = stored_delta * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times + x3 = delta_reshaped * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times + compute_prod3_inplace(x3) # actually computes 3rd power of its arg divided by max(rows, cols)**2 # the factor of 0.5 says we only want to go, at most, half the way to the point which # would give us the minimum 'x'. this is to prevent the largest eigs overshooting @@ -422,7 +447,14 @@ def min_sum_scale(x, y): # the largest singular value; or to prevent the largest singular value from going to # zero if it does dominate. alpha = (0.5 * min_sum_scale(stored_delta, x3)).clamp(min=-1) - stored_delta.add_(x3 * alpha) + delta_reshaped.add_(x3 * alpha) + + if random.random() < 0.001: + rel_scale = (delta_reshaped ** 2).mean().sqrt() / ((1 - beta1**2)**-0.5) + logging.info(f"rel_scale = {rel_scale.item()}") + + if not stored_delta.untyped_storage() is delta_reshaped.untyped_storage(): + stored_delta[:] = delta_reshaped.reshape(*stored_delta.shape) else: stored_delta.mul_(beta1) From 5d8e775f1549082675a07314ed054e39c904b463 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Mar 2026 15:37:23 +0800 Subject: [PATCH 0934/1191] Implement version of TransformedAdam with x3 decay, where row/column scales are applied before and after the x3 decay. --- egs/librispeech/ASR/zipformer/optim.py | 43 +++++++++++++++++++------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index de7ab08782..6840f5b9dd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -428,17 +428,25 @@ def min_sum_scale(x, y): stored_delta.add_(delta) if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): - - delta_reshaped = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) - + d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. # this should be configurable. - linear_decay_scale = 0.25 - delta_reshaped.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) + d.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) excess_scale = 2.5 - x3 = delta_reshaped * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times + x3 = d * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times + + if "delta2_buffer0" not in state: + state["delta2_buffer0"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) + state["delta2_buffer1"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) + delta2_buffer0 = state["delta2_buffer0"] + delta2_buffer1 = state["delta2_buffer1"] + + # we'll scale both before and after the cubing + row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt() * (delta2_buffer1 + eps).sqrt()) + + x3 = x3 * row_col_scale #note, we are before computing the cubed part. compute_prod3_inplace(x3) # actually computes 3rd power of its arg divided by max(rows, cols)**2 # the factor of 0.5 says we only want to go, at most, half the way to the point which @@ -447,14 +455,27 @@ def min_sum_scale(x, y): # the largest singular value; or to prevent the largest singular value from going to # zero if it does dominate. alpha = (0.5 * min_sum_scale(stored_delta, x3)).clamp(min=-1) - delta_reshaped.add_(x3 * alpha) + # we divide x3 by row_col_scale to "un-normalize". + d.add_(x3 * alpha / row_col_scale) + + if random.random() < 0.0005: + rel_scale = (d ** 2).mean().sqrt() / ((1 - beta1**2)**-0.5) + logging.info(f"shape={stored_delta.shape}, rel_scale = {rel_scale.item()}") + + if not stored_delta.untyped_storage() is d.untyped_storage(): + stored_delta[:] = d.reshape(*stored_delta.shape) + beta = beta1 # use this beta for row/col scales + d = d * row_col_scale # half-normalized d + assumed_scale = 0.5 * ((1 - beta1**2)**-0.5) # assumed scalie of d + d2 = (d / assumed_scale) ** 2 if random.random() < 0.001: - rel_scale = (delta_reshaped ** 2).mean().sqrt() / ((1 - beta1**2)**-0.5) - logging.info(f"rel_scale = {rel_scale.item()}") + logging.info(f"shape={stored_delta.shape}, mean of normalized d2 is {d2.mean().item()}") + delta2_buffer0.mul_(beta).add_(d2.mean(dim=2, keepdim=True), alpha=(1 - beta)) + delta2_buffer1.mul_(beta).add_(d2.mean(dim=1, keepdim=True), alpha=(1 - beta)) + d = d * row_col_scale # fully-normalized d - if not stored_delta.untyped_storage() is delta_reshaped.untyped_storage(): - stored_delta[:] = delta_reshaped.reshape(*stored_delta.shape) + stored_delta = d.reshape(*stored_delta.shape) # note: permanent buffer is not updated. else: stored_delta.mul_(beta1) From 28c8a9659660c15268f5273f6953bdaf49fed104 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Mar 2026 15:48:21 +0800 Subject: [PATCH 0935/1191] Bug fix --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6840f5b9dd..1489744dfd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -454,7 +454,7 @@ def min_sum_scale(x, y): # and having the direction change sign, in a situation where we are not dominated by # the largest singular value; or to prevent the largest singular value from going to # zero if it does dominate. - alpha = (0.5 * min_sum_scale(stored_delta, x3)).clamp(min=-1) + alpha = (0.5 * min_sum_scale(d, x3)).clamp(min=-1) # we divide x3 by row_col_scale to "un-normalize". d.add_(x3 * alpha / row_col_scale) From db1bd2e0f8c4a6e7f7b77b8a87d21d6ec5a5395a Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 5 Mar 2026 17:43:59 +0800 Subject: [PATCH 0936/1191] add streaming forward for zapformer --- .../ASR/zapformer/streaming_decode.py | 894 +++++++++++++++++- egs/librispeech/ASR/zapformer/train.py | 1 + .../ASR/zipformer/decode_stream.py | 3 +- egs/librispeech/ASR/zipformer/scaling.py | 70 +- egs/librispeech/ASR/zipformer/subsampling.py | 92 +- egs/librispeech/ASR/zipformer/zipformer.py | 736 +++++++------- 6 files changed, 1403 insertions(+), 393 deletions(-) mode change 120000 => 100755 egs/librispeech/ASR/zapformer/streaming_decode.py diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py deleted file mode 120000 index e31da07d01..0000000000 --- a/egs/librispeech/ASR/zapformer/streaming_decode.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py new file mode 100755 index 0000000000..5c480e117e --- /dev/null +++ b/egs/librispeech/ASR/zapformer/streaming_decode.py @@ -0,0 +1,893 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import LibriSpeech, GigaSpeech, AsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet, set_caching_enabled +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""" + ) + + parser.add_argument( + "--giga", + type=str2bool, + default=False, + help="""If True, decode gigaspeech in addition to librispeech test sets.""", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*5:(i+1)*5] + is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_caches(batch_size, device) + + embed_states = model.encoder_embed.get_init_cache(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 5 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 5 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 5 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_value: (left_context_len, batch_size, value_dim) + cached_value = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_conv: (batch_size, channels, left_pad) + cached_conv = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=0 + ) + # cached_norm_stats: (batch_size, ...) + cached_norm_stats = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=0 + ) + # cached_norm_len: (batch_size, ...) + cached_norm_len = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 5 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 5 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 5 + # chunk dim=1 for attention maps + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + cached_value_list = batch_states[layer_offset + 1].chunk(chunks=batch_size, dim=1) + + # chunk dim=0 for conv and norm stats + cached_conv_list = batch_states[layer_offset + 2].chunk(chunks=batch_size, dim=0) + cached_norm_stats_list = batch_states[layer_offset + 3].chunk(chunks=batch_size, dim=0) + cached_norm_len_list = batch_states[layer_offset + 4].chunk(chunks=batch_size, dim=0) + + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_value_list[i], + cached_conv_list[i], + cached_norm_stats_list[i], + cached_norm_len_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cache=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + caches=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + tail_length = chunk_size * 2 + 7 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + recogs_filename = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + logging.info(f"The transcripts are stored in {recogs_filename}") + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + if params.label: + params.suffix += f"-{params.label}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeech(args.manifest_dir) + + test_sets = [] + test_cuts = [] + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_sets += ["dev-clean", "dev-other", "test-clean", "test-other"] + test_cuts += [dev_clean_cuts, dev_other_cuts, test_clean_cuts, test_other_cuts] + + if args.giga: + gigaspeech = GigaSpeech(args.manifest_dir) + test_cuts = gigaspeech.test_cuts() + dev_cuts = gigaspeech.dev_cuts() + test_sets += ["dev", "test"] + test_cuts += [dev_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d8e9b7e458..c055853234 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -669,6 +669,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, out_channels=lookup(params, "embed_dim"), + causal=params.causal, ) return encoder_embed diff --git a/egs/librispeech/ASR/zipformer/decode_stream.py b/egs/librispeech/ASR/zipformer/decode_stream.py index d6918bf328..a1bf671bf5 100644 --- a/egs/librispeech/ASR/zipformer/decode_stream.py +++ b/egs/librispeech/ASR/zipformer/decode_stream.py @@ -75,8 +75,7 @@ def __init__( self.done_frames: int = 0 # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 + self.pad_length = 7 if params.decoding_method == "greedy_search": self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3bc82ca273..954d4665ea 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -362,6 +362,45 @@ def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor return x * ((scale * scales) + offset) +# all arg tensors are scalars +def _sequence_norm_streaming( + x: Tensor, + offset: Tensor, + scale: Tensor, + cached_stats_sum: Tensor, + cached_len: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """Streaming inference forward for _sequence_norm. We assume that ballast_frames and ballast_rms + are already included in cached_stats_sum and cached_len. + + Args: + x: (seq_len, batch_size, channels) + offset: scalar + scale: scalar + cached_stats_sum: (batch_size,) + cached_len: (batch_size,) + + Returns: + - normalized x, (seq_len, batch_size, channels) + - updated cached_stats_sum, (batch_size,) + - updated cached_len, (batch_size,) + """ + stats = (x ** 2).mean(dim=2, keepdim=True) # (seq_len, batch_size, 1) + + T = x.shape[0] # time + + stats = stats.cumsum(dim=0) + cached_stats_sum.unsqueeze(-1) + lengths = cached_len[:, None] + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] + + # update cached_stats_sum and cached_len for the next chunk + cached_stats_sum = stats[-1].squeeze(-1) # (batch_size,) + cached_len = cached_len + T + + scales = (lengths / stats).sqrt() # (T, batch_size, 1) + assert scales.shape == (T, x.shape[1], 1) + return x * ((scale * scales) + offset), cached_stats_sum, cached_len + + class SequenceNormFunction(torch.autograd.Function): @staticmethod def forward( @@ -427,7 +466,6 @@ def __init__( self.causal = causal self.name = None - def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: # x: (seq, batch, channel) # mask: bool, (batch_size, seq_len) @@ -459,6 +497,30 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: return ans + @torch.jit.export + def get_init_cache(self, batch_size: int): + """Get initial cache for streaming inference. We first include the ballast stats in the initial cache. + """ + # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so + # make absolutely sure using abs(). + ballast_frames = 100.0 * self.ballast_frames.abs() + ballast = ballast_frames * (self.ballast_rms ** 2) + + cached_stats_sum = ballast.unsqueeze(0).repeat(batch_size) # (batch_size,) + cached_len = ballast_frames.unsqueeze(0).repeat(batch_size) # (batch_size,) + + return cached_stats_sum, cached_len + + def streaming_forward( + self, + x: Tensor, + cached_stats_sum: Tensor, + cached_len: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + + x, cached_stats_sum, cached_len = _sequence_norm_streaming( + x, self.offset, self.scale, cached_stats_sum, cached_len) + return x, cached_stats_sum, cached_len # assume layout: (time, batch, channel) @@ -1535,6 +1597,7 @@ def __init__( # both of these are added to a default scale of 1.0. self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) self.kernel_size = kernel_size + self.left_pad = half_kernel_size - 1 with torch.no_grad(): self.causal_conv.weight[:] *= initial_scale @@ -1553,11 +1616,10 @@ def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: """ (batch_size, num_channels, seq_len) = x.shape - # half_kernel_size = self.kernel_size + 1 // 2 # left_pad is half_kernel_size - 1 where half_kernel_size is the size used # in the causal conv. It's the amount by which we must pad on the left, # to make the convolution causal. - left_pad = self.kernel_size // 2 + left_pad = self.left_pad if chunk_size < 0 or chunk_size > seq_len: chunk_size = seq_len @@ -1622,7 +1684,7 @@ def streaming_forward( # left_pad is half_kernel_size - 1 where half_kernel_size is the size used # in the causal conv. It's the amount by which we must pad on the left, # to make the convolution causal. - left_pad = self.kernel_size // 2 + left_pad = self.left_pad # Pad cache assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 839e848e69..0959f417fc 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -61,17 +61,25 @@ def __init__( channels: int, hidden_ratio: int = 3, kernel_size: Tuple[int, int] = (7, 7), + causal: bool = False, ): super().__init__() - self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1 + self.causal = causal hidden_channels = channels * hidden_ratio + if not causal: + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + else: + padding = (0, kernel_size[1] // 2) + self.left_pad = kernel_size[0] - 1 + self.depthwise_conv = nn.Conv2d( in_channels=channels, out_channels=channels, groups=channels, kernel_size=kernel_size, - padding=self.padding, + padding=padding, ) self.pointwise_conv1 = nn.Conv2d( @@ -86,7 +94,6 @@ def __init__( kernel_size=1, ) - def forward( self, x: Tensor, ) -> Tensor: @@ -96,7 +103,12 @@ def forward( The returned value has the same shape as x. """ bypass = x + + if self.causal: + x = nn.functional.pad(x, (0, 0, self.left_pad, 0)) x = self.depthwise_conv(x) + assert x.shape == bypass.shape, (x.shape, bypass.shape) + x = self.pointwise_conv1(x) x = self.activation(x) x = self.pointwise_conv2(x) @@ -104,51 +116,38 @@ def forward( x = bypass + x return x - + def streaming_forward( self, x: Tensor, - cached_left_pad: Tensor, + cache: Tensor, ) -> Tuple[Tensor, Tensor]: """ Args: x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) + cache: (batch_size, num_channels, left_pad, num_freqs) Returns: - The returned value has the same shape as x. - - Updated cached_left_pad. + - Updated cache. """ - padding = self.padding + bypass = x - # The length without right padding for depth-wise conv - T = x.size(2) - padding[0] + # Pad left side with cache, and update cache + assert cache.size(2) == self.left_pad + x = torch.cat([cache, x], dim=2) + cache = x[:, :, -self.left_pad :, :] - bypass = x[:, :, :T, :] + x = self.depthwise_conv(x) + assert x.shape == bypass.shape, (x.shape, bypass.shape) - # Pad left side - assert cached_left_pad.size(2) == padding[0], ( - cached_left_pad.size(2), - padding[0], - ) - x = torch.cat([cached_left_pad, x], dim=2) - # Update cached left padding - cached_left_pad = x[:, :, T : padding[0] + T, :] - - # depthwise_conv - x = torch.nn.functional.conv2d( - x, - weight=self.depthwise_conv.weight, - bias=self.depthwise_conv.bias, - padding=(0, padding[1]), - groups=self.depthwise_conv.groups, - ) x = self.pointwise_conv1(x) x = self.activation(x) x = self.pointwise_conv2(x) x = bypass + x - return x, cached_left_pad + + return x, cache class Conv2dSubsampling(nn.Module): @@ -169,6 +168,7 @@ def __init__( layer1_channels: int = 16, layer2_channels: int = 64, layer3_channels: int = 128, + causal: bool = False, ) -> None: """ Args: @@ -220,15 +220,13 @@ def __init__( SwashR(), ) - # just one convnext layer - self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7), causal=causal) # (in_channels-3)//4 self.out_width = (((in_channels - 1) // 2) - 1) // 2 self.layer3_channels = layer3_channels - # scale it up a bit, else the output is quite small. self.out = ScaledLinear(self.out_width * layer3_channels, out_channels) @@ -277,7 +275,7 @@ def streaming_forward( self, x: torch.Tensor, x_lens: torch.Tensor, - cached_left_pad: Tensor, + cache: Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Subsample x. @@ -286,6 +284,8 @@ def streaming_forward( Its shape is (N, T, idim). x_lens: A tensor of shape (batch_size,) containing the number of frames in + cache: + The cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) Returns: - a tensor of shape (N, (T-7)//2, odim) @@ -298,10 +298,7 @@ def streaming_forward( # T' = (T-7)//2 x = self.conv(x) - # T' = (T-7)//2-3 - x, cached_left_pad = self.convnext.streaming_forward( - x, cached_left_pad=cached_left_pad - ) + x, cache = self.convnext.streaming_forward(x, cache=cache) # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) b, c, t, f = x.size() @@ -309,24 +306,21 @@ def streaming_forward( x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, T', out_width * layer3_channels)) + x = self.out(x) # Now x is of shape (N, T', odim) if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert self.convnext.padding[0] == 3 - # The ConvNeXt module needs 3 frames of right padding after subsampling - x_lens = (x_lens - 7) // 2 - 3 + x_lens = (x_lens - 7) // 2 else: with warnings.catch_warnings(): warnings.simplefilter("ignore") - # The ConvNeXt module needs 3 frames of right padding after subsampling - assert self.convnext.padding[0] == 3 - x_lens = (x_lens - 7) // 2 - 3 + x_lens = (x_lens - 7) // 2 assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) - return 0.15 * x, x_lens, cached_left_pad + return 0.15 * x, x_lens, cache @torch.jit.export - def get_init_states( + def get_init_cache( self, batch_size: int = 1, device: torch.device = torch.device("cpu"), @@ -335,11 +329,9 @@ def get_init_states( It is the cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) """ - left_pad = self.convnext.padding[0] + left_pad = self.convnext.left_pad freq = self.out_width channels = self.layer3_channels - cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to( - device - ) + cache = torch.zeros(batch_size, channels, left_pad, freq, device=device) - return cached_embed_left_pad + return cache diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e3f7dfd2b6..16b733b0a5 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -50,6 +50,8 @@ from torch import Tensor, nn +from icefall.utils import make_pad_mask + class Zipformer2(EncoderInterface): """ @@ -132,8 +134,8 @@ def _to_tuple(x): self.conv_params = conv_params = _to_tuple(conv_params) self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames + self.chunk_size = (chunk_size,) if isinstance(chunk_size, int) else chunk_size + self.left_context_frames = (left_context_frames,) if isinstance(left_context_frames, int) else left_context_frames # each one will be Zipformer2Encoder or OrthogonalDownsample or OrthogonalUpsample encoders = [] @@ -254,7 +256,7 @@ def forward( T = x.shape[0] x = module( x, - chunk_size=chunk_size, + chunk_size=chunk_size // ds if chunk_size > 0 else -1, src_key_padding_mask=( None if src_key_padding_mask is None @@ -322,13 +324,12 @@ def _get_attn_mask( logging.info(f"attn_mask = {attn_mask}") return attn_mask - def streaming_forward( self, x: Tensor, x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, + caches: List[Tensor], + src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, List[Tensor]]: """ Args: @@ -337,60 +338,87 @@ def streaming_forward( x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*5:(i+1)*5] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv) + caches: list of cached tensors of all encoder layers. For layer-i, + caches[i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len). src_key_padding_mask: The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. + Returns: - Return a tuple containing 2 tensors: - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - lengths, a tensor of shape (batch_size,) containing the number of frames in `embeddings` before padding. - - updated states + - updated caches: an updated list of cache tensors. """ - new_states = [] + orig_seq_len = x.shape[0] + pad = (-orig_seq_len) % max(self.downsampling_factor) + # pad sequence length to be multiple of max(self.downsampling_factor) + x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), dim=0) + + if src_key_padding_mask is not None: + left_context_frames = src_key_padding_mask.shape[1] - orig_seq_len + assert left_context_frames == self.left_context_frames[0] + if pad > 0: + src_key_padding_mask = torch.cat( + (src_key_padding_mask[:, :left_context_frames], + pad_mask(src_key_padding_mask[:, left_context_frames:], x.shape[0])), + dim=1, + ) + + new_caches = [] layer_offset = 0 for i, module in enumerate(self.encoders): num_layers = module.num_layers ds = self.downsampling_factor[i] - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 5], + x = downsample_by(x, ds) + + # Slice out the specific caches for the current module + module_caches = caches[layer_offset * 5 : (layer_offset + num_layers) * 5] + + x, new_module_caches = module.streaming_forward( + src=x, + caches=module_caches, left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), ) + layer_offset += num_layers - new_states += new_layer_states + new_caches.extend(new_module_caches) - x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. + x = upsample_by(x, ds) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 + # Output downsampling and normalization + od = self.output_downsampling_factor + x = downsample_by(x, od) + + x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding - return x, lengths, new_states + if od > 1: + x_lens = (x_lens + od - 1) // od + + x = self.out_norm(x) + return x, x_lens, new_caches + @torch.jit.export - def get_init_states( + def get_init_caches( self, batch_size: int = 1, device: torch.device = torch.device("cpu"), ) -> List[Tensor]: - """Get initial states. + """Get initial caches. - A list of cached tensors of all encoder layers. For layer-i, states[i*5:(i+1)*5] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + A list of cached tensors of all encoder layers. For layer-i, caches[i*5:(i+1)*5] + is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len). """ - states = [] + caches = [] for i, module in enumerate(self.encoders): num_layers = module.num_layers embed_dim = self.encoder_dim[i] @@ -399,37 +427,27 @@ def get_init_states( key_dim = self.query_head_dim[i] * num_heads value_dim = self.value_head_dim[i] * num_heads downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 # will be error. have to figure this out. - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( - device - ) - cached_nonlin_attn = torch.zeros( - 1, batch_size, downsample_left, nonlin_attn_head_dim - ).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - states += [ + + # (self.conv_params[i] + 1) // 2 is the size used in the depthwise causal conv. + conv_left_pad = (self.conv_params[i] + 1) // 2 - 1 + + for layer_idx, enc_layer in enumerate(module.layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim, device=device) + cached_value = torch.zeros(downsample_left, batch_size, value_dim, device=device) + cached_conv = torch.zeros(batch_size, embed_dim, conv_left_pad, device=device) + cached_norm_stats, cached_norm_len = enc_layer.norm.get_init_cache(batch_size) + cached_norm_stats = cached_norm_stats.to(device) + cached_norm_len = cached_norm_len.to(device) + + caches.extend([ cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + ]) - return states + return caches def pad_mask(mask: Optional[Tensor], seq_len: int): @@ -540,7 +558,6 @@ def __init__( self.norm = SequenceNorm(causal=causal) - def forward( self, src: Tensor, @@ -596,105 +613,85 @@ def forward( src = self.norm(src, src_key_padding_mask) return src - + def streaming_forward( self, src: Tensor, cached_key: Tensor, - cached_nonlin_attn: Tensor, - cached_val1: Tensor, - cached_val2: Tensor, + cached_value: Tensor, cached_conv: Tensor, + cached_norm_stats: Tensor, + cached_norm_len: Tensor, left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Pass the input through the encoder layer in streaming forward mode. Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - cached_val1: cached left context for the first attention module, - of shape (left_context_len, batch_size, value_dim) - cached_val2: cached left context for the second attention module, - of shape (left_context_len, batch_size, value_dim) - cached_conv: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) + cached_key: cached attention key tensor, of shape (left_context_len, batch_size, key_dim) + cached_value: cached attention value tensor, of shape (left_context_len, batch_size, value_dim) + cached_conv: cached left context for the convolution module, of shape (batch_size, channels, left_pad) + cached_norm_stats: cached SequenceNorm stats, of shape (batch_size,) + cached_norm_len: cached SequenceNorm length, scalar. left_context_len: number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len + seq_len); + True means masked position. May be None. Returns: - x, with the same shape as src - updated cached_key - - updated cached_nonlin_attn - - updated cached_val1 - - updated cached_val2 + - updated cached_value - updated cached_conv + - updated cached_norm_stats + - updated cached_norm_len """ src_orig = src - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights, cached_key = self.self_attn_weights.streaming_forward( - src, - cached_key=cached_key, - left_context_len=left_context_len, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) + src_pre_ff1 = src + chunk_mask = None if src_key_padding_mask is None else src_key_padding_mask[:, left_context_len:] - na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( - src, - attn_weights[0:1], - cached_x=cached_nonlin_attn, - left_context_len=left_context_len, - ) - src = src + na + src = src + self.feed_forward1(src, src_key_padding_mask=chunk_mask) - self_attn, cached_val1 = self.self_attn1.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val1, + # may try changing src_pre_ff1 to src or vice versa. + self_attn_out, cached_key, cached_value = self.self_attn.streaming_forward( + x_qkp=src_pre_ff1, + x_vg=src, left_context_len=left_context_len, + cached_key=cached_key, + cached_value=cached_value, + key_padding_mask=src_key_padding_mask, ) - src = src + self_attn + src = src + self_attn_out src_conv, cached_conv = self.conv_module.streaming_forward( - src, + 3.0 * src, cache=cached_conv, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + src_key_padding_mask=chunk_mask, ) src = src + src_conv - src = src + self.feed_forward2(src) - + src = src + self.feed_forward2(src, src_key_padding_mask=chunk_mask) - self_attn, cached_val2 = self.self_attn2.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val2, - left_context_len=left_context_len, - ) - src = src + self_attn - - offset = (src - src_orig) * self.residual_scale + residual_scale = 0.25 + offset = (src - src_orig) * residual_scale src = src_orig + offset - src = self.norm(src) + src, cached_norm_stats, cached_norm_len = self.norm.streaming_forward( + src, + cached_stats_sum=cached_norm_stats, + cached_len=cached_norm_len, + ) return ( src, cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, + cached_value, cached_conv, + cached_norm_stats, + cached_norm_len, ) @@ -739,8 +736,6 @@ def __init__( self.copy_bypass = Identity() - - def forward( self, src: Tensor, @@ -802,23 +797,22 @@ def forward( # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, # because of some identities involving orthogonal matrices. - return src - def streaming_forward( self, src: Tensor, - states: List[Tensor], + caches: List[Tensor], left_context_len: int, - src_key_padding_mask: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. + r"""Pass the input through the encoder layers in turn in streaming mode. Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*5:(i+1)*5] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv). + caches: list of cached tensors of N encoder layers. For layer-i, + caches[i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len). left_context_len: Number of left context frames. src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len + seq_len); True means masked position. @@ -826,51 +820,66 @@ def streaming_forward( Returns: - output, a Tensor with the same shape as src. - - updated states + - updated caches """ - num_channels = src.shape[-1] - layer_dim = self.layers[0].embed_dim - if num_channels > layer_dim: - src, bypass = src[..., :layer_dim], src[..., layer_dim:] + src_orig_fulldim = src + + # project to layer dim. + src = self.proj(src) + + num_layers = len(self.layers) + assert len(caches) == num_layers * 5 + + residual_scale = self.residual_scales[0] + input_scale = self.input_scale - new_states = [] + src_with_bypass = residual_scale * src + src = input_scale * src + + new_caches = [] for i, mod in enumerate(self.layers): ( cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, + cached_value, cached_conv, - ) = states[i * 5 : (i + 1) * 5] + cached_norm_stats, + cached_norm_len, + ) = caches[i * 5 : (i + 1) * 5] + ( src, new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, + new_cached_value, new_cached_conv, + new_cached_norm_stats, + new_cached_norm_len, ) = mod.streaming_forward( src, cached_key=cached_key, - cached_nonlin_attn=cached_nonlin_attn, - cached_val1=cached_val1, - cached_val2=cached_val2, + cached_value=cached_value, cached_conv=cached_conv, + cached_norm_stats=cached_norm_stats, + cached_norm_len=cached_norm_len, left_context_len=left_context_len, src_key_padding_mask=src_key_padding_mask, ) - new_states += [ + + layer_residual_scale = self.residual_scales[i + 1] + + src_with_bypass = src_with_bypass + layer_residual_scale * src + + new_caches.extend([ new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, + new_cached_value, new_cached_conv, - ] + new_cached_norm_stats, + new_cached_norm_len, + ]) - if num_channels > layer_dim: - src = torch.cat((src, bypass), dim=-1) + offset = src_with_bypass + src = src_orig_fulldim + self.proj(offset, transpose=True) - return src, new_states + return src, new_caches class ResidualModule(nn.Module): @@ -1074,7 +1083,6 @@ def __init__( self.vg_in_proj = ScaledLinear(embed_dim, 2 * num_heads * value_head_dim, initial_scale=0.1, bias=True) - self.copy_v = nn.Identity() # diagnostics. self.sigmoid = nn.Sigmoid() @@ -1083,8 +1091,6 @@ def __init__( num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.5 ) - - def forward( self, x_qkp: Tensor, @@ -1169,8 +1175,6 @@ def forward( 0.1 * aux_loss_scale, key_padding_mask, self.name) - - # We use our own version of softmax, defined in scaling.py, which should # save a little of the memory used in backprop by, if we are in # automatic mixed precision mode (amp / autocast), by only storing the @@ -1203,7 +1207,6 @@ def forward( .view(seq_len, batch_size, num_heads * value_head_dim) ) - if self.training: # don't let the sigmoid values get too extreme, limit to -2..2. g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) @@ -1212,142 +1215,101 @@ def forward( v = v * self.sigmoid(g) v = self.out_proj(v) return v - - def streaming_forward_weights( # TODO: fix and test, needs to be combined with value and gating stuff, - # see streaming_forward_vg which I took from the old class. + + def streaming_forward( self, - x: Tensor, - cached_key: Tensor, + x_qkp: Tensor, + x_vg: Tensor, left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: + cached_key: Tensor, + cached_value: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor]: r""" Args: - x: input of shape (seq_len, batch_size, embed_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - + x_qkp: input of shape (seq_len, batch_size, embed_dim), that is used for the queries, + keys and positions. + x_vg: input of shape (seq_len, batch_size, embed_dim), that is used for the values + and gates. May be the same as x_qk. + left_context_len: length of the cached left context. + cached_key: cached attention key tensor, of shape (left_context_len, batch_size, key_dim). + cached_value: cached attention value tensor, of shape (left_context_len, batch_size, value_dim). + key_padding_mask: a bool tensor of shape (batch_size, left_context_len + seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. + - attention output, of shape (seq_len, batch_size, embed_dim) + - updated cached_key, of shape (left_context_len, batch_size, key_dim) + - updated cached_value, of shape (left_context_len, batch_size, value_dim) """ - x = self.in_proj(x) query_head_dim = self.query_head_dim num_heads = self.num_heads + x_qkp = self.in_norm(x_qkp) + x_qkp = self.qkp_in_proj(x_qkp) - seq_len, batch_size, _ = x.shape + seq_len, batch_size, _ = x_qkp.shape query_dim = query_head_dim * num_heads # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] + q = x_qkp[..., 0:query_dim] + k = x_qkp[..., query_dim : 2 * query_dim] + p = x_qkp[..., 2 * query_dim:] - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, ( - cached_key.shape[0], - left_context_len, - ) + # append the cached key to the current key, and update the cache + assert cached_key.shape[0] == left_context_len, (cached_key.shape, left_context_len) k = torch.cat([cached_key, k], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] + kv_len = k.shape[0] + cached_key = k[kv_len - left_context_len:] q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + k = k.reshape(kv_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, -1) # time1 refers to target, time2 refers to source. q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + k = k.permute(2, 1, 3, 0) # (head, batch, query_head_dim, time2) - attn_scores = torch.matmul(q, k) - - assert attn_scores.shape == ( - num_heads, - batch_size, - seq_len, - k_len, - ), attn_scores.shape + attn_scores = torch.matmul(q, k) # (head, batch, time1, time2) - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) + p = p.permute(1, 2, 0, 3) + pos_scores = self.rel_pos(p, left_context_len) # (batch, head, time1, time2) + attn_scores = attn_scores + pos_scores.permute(1, 0, 2, 3) + assert attn_scores.shape == (num_heads, batch_size, seq_len, kv_len) - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. - attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, aux_loss_scale, - key_padding_mask, self.name) + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, kv_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), -1000) attn_weights = attn_scores.softmax(dim=-1) - return attn_weights, cached_key - - def streaming_forward_vg( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, ( - cached_val.shape[0], - left_context_len, - ) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] + # append the cached value to the current value, and update the cache + assert cached_value.shape[0] == left_context_len, (cached_value.shape, left_context_len) + v = torch.cat([cached_value, v], dim=0) + cached_value = v[kv_len - left_context_len:] - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] + v = v.reshape(kv_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + value_head_dim = v.shape[-1] + # now v: (num_heads, batch_size, kv_len, value_head_dim) # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) + v = torch.matmul(attn_weights, v) # v: (num_heads, batch_size, seq_len, value_head_dim) - x = ( - x.permute(2, 1, 0, 3) + v = ( + v.permute(2, 1, 0, 3) .contiguous() .view(seq_len, batch_size, num_heads * value_head_dim) ) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - return x, cached_val + v = v * self.sigmoid(g) + v = self.out_proj(v) + return v, cached_key, cached_value def _print_attn_entropy(self, attn_weights: Tensor): # attn_weights: (num_heads, batch_size, seq_len, seq_len) @@ -1366,7 +1328,6 @@ def _print_attn_entropy(self, attn_weights: Tensor): ) - class PenalizeLargeAttentionScores(torch.autograd.Function): @staticmethod def forward( @@ -1555,7 +1516,6 @@ def compute_angular_freq_basis_triangular(freqs: Tensor, return torch.stack((re, im), dim=-1).to(dtype) - class RelPosScores(nn.Module): def __init__(self, num_heads: int, @@ -1572,60 +1532,56 @@ def __init__(self, for _ in range(10): self.weight[:] = (2 ** -0.5) * (self.weight + self.weight.roll(1, dims=2)) - log_freqs = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs) freqs = math.pi * (log_freqs.exp() - low_freq_factor) # these range from 0 to pi. freqs[0] = 0.0 # in case of roundoff (it should be 0, mathematically) self.register_buffer('freqs', freqs, persistent=False) - - def forward(self, p: Tensor) -> Tensor: + def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: """ Compute and return unnormalized log scores for relative position. Args: p: these are the position-queries, of shape (batch_size, num_heads, seq_len, pos_dim) (they are obtained via projection, just like the queries). + left_context_len: length of left context, must be 0 for non-streaming forward and > 0 for streaming forward. Returns: - scores: (batch_size, num_heads, dest_seq_len, src_seq_len), - - where dest_seq_len and src_seq_len are numerically equal to seq_len but dest_seq_len relates to the - query and src_seq_len to the key. + scores: (batch_size, num_heads, dest_seq_len, src_seq_len), where dest_seq_len relates to the + query and src_seq_len to the key. + In non-streaming forward, dest_seq_len and src_seq_len are numerically equal to seq_len; + in streaming forward, dest_seq_len is seq_len and src_seq_len is seq_len + left_context_len. """ - (batch_size, num_heads, seq_len, pos_dim) = p.shape - - freqs = self.freqs # base freqs - t = torch.arange(-(seq_len - 1), seq_len, device=p.device) + t = torch.arange(-(seq_len + left_context_len - 1), seq_len, device=p.device) basis = compute_angular_freq_basis_triangular(freqs, t, scale=False) - # basis: (2 * seq_len - 1, num_freqs, 2) + # basis: (2 * seq_len + left_context_len - 1, num_freqs, 2) basis = basis.permute(0, 2, 1) # permute it because of how we did the low-pass initialization of weight, we want # the cos and sin parts to each be continuous ranges, not interleaved. - basis = basis.reshape(basis.shape[0], -1) # (2 * seq_len - 1, 2 * num_freqs) + basis = basis.reshape(basis.shape[0], -1) # (2 * seq_len + left_context_len - 1, 2 * num_freqs) x = torch.matmul(self.weight, basis.t()) - assert x.shape == (num_heads, pos_dim, 2 * seq_len - 1) + assert x.shape == (num_heads, pos_dim, 2 * seq_len + left_context_len - 1) - # with seq_len2 = 2 * seq_len - 1, - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # with seq_len2 = 2 * seq_len + left_context_len - 1, + # (batch, head, seq_len, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, seq_len, seq_len2) pos_weights = torch.matmul(p, x) # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. This is all copied from our old conformer/zipformer code. if torch.jit.is_tracing(): - (batch_size, num_heads, time1, n) = pos_weights.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) + seq_len2 = pos_weights.shape[-1] + rows = torch.arange(start=seq_len - 1, end=-1, step=-1) + cols = torch.arange(left_context_len + seq_len) rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) indexes = rows + cols - pos_weights = pos_weights.reshape(-1, n) + pos_weights = pos_weights.reshape(-1, seq_len2) pos_weights = torch.gather(pos_weights, dim=1, index=indexes) - pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) - else: + pos_weights = pos_weights.reshape(batch_size, num_heads, seq_len, left_context_len + seq_len) + else: pos_weights = pos_weights.as_strided( - (batch_size, num_heads, seq_len, seq_len), + (batch_size, num_heads, seq_len, left_context_len + seq_len), ( pos_weights.stride(0), pos_weights.stride(1), @@ -1634,11 +1590,9 @@ def forward(self, p: Tensor) -> Tensor: ), storage_offset=pos_weights.stride(3) * (seq_len - 1), ) - return pos_weights - class FftConv(nn.Module): def __init__(self, num_channels: int, @@ -1682,8 +1636,6 @@ def forward(self, return x - - class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer2 model. Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py @@ -1713,7 +1665,6 @@ def __init__( # the gradients on in_proj are a little noisy, likely to do with the # sigmoid in glu. - self.activation1 = Identity() # for diagnostics self.sigmoid1 = nn.Sigmoid() @@ -1722,7 +1673,10 @@ def __init__( self.activation2 = Identity() # for diagnostics - self.depthwise_conv = FftConv(bottleneck_dim, kernel_size) + if not causal: + self.depthwise_conv = FftConv(bottleneck_dim, kernel_size) + else: + self.depthwise_conv = ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, @@ -1748,9 +1702,7 @@ def forward( Returns: Tensor: Output tensor (#time, batch, channels). - """ - # x: (time, batch, channels) # Caution: this module is not completely # invariant to the number of frames each sequence is padded with, since @@ -1758,7 +1710,7 @@ def forward( if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - x = self.in_proj(x) # (time, batch, 2*channels) + x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) x, s, y = x.chunk(3, dim=2) s = self.sigmoid1(s) @@ -1767,81 +1719,87 @@ def forward( x = x * s x = self.activation2(x) # identity - x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) + if self.causal: + # Not support exporting a model for simulated streaming decoding + assert not torch.jit.is_scripting() and not torch.jit.is_tracing() + x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) + # for the causal version, we don't use fft-conv + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x, chunk_size=chunk_size) + x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) + else: + x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) x = x * y x = self.out_proj(x) # (time, batch, channels) return x - - def repeat_in_padding(self, x, mask): - # repeats elements of x in the padding region, circularly as much as possible; - # the discontinuity between the ones that circularly repeat from the end and - # those that circularly repeat from the beginning is in the middle of the padding - # region. - - # x: (seq_len, batch_size, num_channels) - (batch_size, seq_len) = mask.shape - - seq_lengths = (~mask).to(torch.int64).sum(dim=1, keepdim=True) # (batch_size, 1) - pad_len = seq_len - seq_lengths - arange = torch.arange(seq_len, device=mask.device) - - # "mid" gives the index of the midpoint of the padding region after each sequence. - mid = (seq_lengths + seq_len) // 2 # mid: (batch_size, 1) - - src_index = torch.where(arange >= mid, arange - pad_len, arange) % seq_lengths - # src_index: (batch_size, seq_len) - - src_index = src_index.t().unsqueeze(-1).expand_as(x) - # src_index: (seq_len, batch_size, num_channels) - x = torch.gather(x, dim=0, index=src_index) - return x - - - def streaming_forward( self, x: Tensor, cache: Tensor, - src_key_padding_mask: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. + """Compute convolution module. Args: x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) + cache: cached left context for depthwise_conv, of shape + (#batch, channels, left_pad) src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. + (batch, #time), contains True in masked positions. Returns: - Output tensor (#time, batch, channels). - Updated cache (#batch, channels, left_pad) """ + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - x = self.in_proj(x) # (time, batch, 2*channels) + x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) - x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) + x, s, y = x.chunk(3, dim=2) + s = self.sigmoid1(s) + y = self.sigmoid2(y) x = x * s - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - + + x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) - x = x.permute(2, 0, 1) # (time, batch, channels) - + x = x * y x = self.out_proj(x) # (time, batch, channels) return x, cache + def repeat_in_padding(self, x, mask): + # repeats elements of x in the padding region, circularly as much as possible; + # the discontinuity between the ones that circularly repeat from the end and + # those that circularly repeat from the beginning is in the middle of the padding + # region. + + # x: (seq_len, batch_size, num_channels) + (batch_size, seq_len) = mask.shape + + seq_lengths = (~mask).to(torch.int64).sum(dim=1, keepdim=True) # (batch_size, 1) + pad_len = seq_len - seq_lengths + arange = torch.arange(seq_len, device=mask.device) + + # "mid" gives the index of the midpoint of the padding region after each sequence. + mid = (seq_lengths + seq_len) // 2 # mid: (batch_size, 1) + + src_index = torch.where(arange >= mid, arange - pad_len, arange) % seq_lengths + # src_index: (batch_size, seq_len) + + src_index = src_index.t().unsqueeze(-1).expand_as(x) + # src_index: (seq_len, batch_size, num_channels) + x = torch.gather(x, dim=0, index=src_index) + return x + class ScalarMultiply(nn.Module): def __init__(self, scale: float): @@ -1884,6 +1842,111 @@ def _test_zipformer_main(causal: bool = False): ) x_ # to remove flake8 warnings + logging.info(f"Zipformer forward test passed, causal={causal}") + + +def _test_zipformer_streaming(): + input_dim = 50 + batch_size = 2 + chunk_size = 32 + num_chunks = 3 + tail_chunk_size = 8 + seq_len = chunk_size * num_chunks + tail_chunk_size + left_context_frames = 128 + + model = Zipformer2( + input_dim=input_dim, + encoder_dim=(64, 96, 128, 96), + num_heads=(4, 4, 4, 4), + conv_params=(7, 7, 7, 7), + downsampling_factor=(1, 2, 4, 2), + causal=True, + chunk_size=(chunk_size,), + left_context_frames=(left_context_frames,), + ) + + model.eval() + + x_full = torch.randn(seq_len, batch_size, input_dim) + x_lens_full = torch.full((batch_size,), seq_len, dtype=torch.int64) + + with torch.no_grad(): + out_full, out_lens_full = model(x_full, x_lens_full) + + caches = model.get_init_caches(batch_size=batch_size) + + out_chunks = [] + out_offset = 0 + processed_lens = torch.full((batch_size,), 0, dtype=torch.int64) + + for i in range(num_chunks): + start = i * chunk_size + end = start + chunk_size + x_chunk = x_full[start:end] + x_lens = torch.full((batch_size,), chunk_size, dtype=torch.int64) + + src_key_padding_mask = make_pad_mask(x_lens) + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_frames).expand(batch_size, left_context_frames) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + out_chunk, out_lens, caches = model.streaming_forward( + x=x_chunk, + x_lens=x_lens, + caches=caches, + src_key_padding_mask=src_key_padding_mask, + ) + out_chunks.append(out_chunk) + + out_chunk_len = out_chunk.shape[0] + expected_out = out_full[out_offset : out_offset + out_chunk_len] + diff_chunk = torch.max(torch.abs(expected_out - out_chunk)) + logging.info(f"Chunk {i+1} | Input: {x_chunk.shape} -> Output: {out_chunk.shape} | Max diff: {diff_chunk}") + assert torch.allclose(expected_out, out_chunk, atol=2e-5), f"Chunk {i+1} outputs do not match! Max diff: {diff_chunk}" + + out_offset += out_chunk_len + + x_tail = x_full[num_chunks * chunk_size:] + x_lens_tail = torch.full((batch_size,), tail_chunk_size, dtype=torch.int64) + src_key_padding_mask = make_pad_mask(x_lens_tail) + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_frames).expand(batch_size, left_context_frames) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + processed_lens = processed_lens + x_lens_tail + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + out_tail, out_lens_tail, caches = model.streaming_forward( + x=x_tail, + x_lens=x_lens_tail, + caches=caches, + src_key_padding_mask=src_key_padding_mask, + ) + out_chunks.append(out_tail) + + out_tail_len = out_tail.shape[0] + expected_out_tail = out_full[out_offset : out_offset + out_tail_len] + diff_tail = torch.max(torch.abs(expected_out_tail - out_tail)) + logging.info(f"Tail Chunk | Input: {x_tail.shape} -> Output: {out_tail.shape} | Max diff: {diff_tail}") + assert torch.allclose(expected_out_tail, out_tail, atol=2e-5), f"Tail Chunk outputs do not match! Max diff: {diff_tail}" + out_offset += out_tail_len + + out_stream_cat = torch.cat(out_chunks, dim=0) + + diff = torch.max(torch.abs(out_full - out_stream_cat)) + logging.info(f"Max abs diff between full forward and streaming forward: {diff}") + + assert torch.allclose(out_full, out_stream_cat, atol=2e-5), f"Outputs do not match! Max diff: {diff}" + + logging.info("Zipformer streaming_forward test passed") if __name__ == "__main__": @@ -1892,3 +1955,4 @@ def _test_zipformer_main(causal: bool = False): torch.set_num_interop_threads(1) _test_zipformer_main(False) _test_zipformer_main(True) + _test_zipformer_streaming() From c768e4853e3722dd9bfa58fba2083f1db9211433 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 5 Mar 2026 22:00:00 +0800 Subject: [PATCH 0937/1191] Do not double count the scalar component of row_col_scale. --- egs/librispeech/ASR/zipformer/optim.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 1489744dfd..d792a23ff5 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -443,8 +443,14 @@ def min_sum_scale(x, y): delta2_buffer0 = state["delta2_buffer0"] delta2_buffer1 = state["delta2_buffer1"] - # we'll scale both before and after the cubing - row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt() * (delta2_buffer1 + eps).sqrt()) + # we'll scale both before and after the cubing. + # the lines where we divide by sqrt of the mean are so we don't double + # count the scalar component of this. + factor0 = (delta2_buffer0 + eps).sqrt() + factor0 = factor0 / factor0.mean(dim=1, keepdim=True).sqrt() + factor1 = (delta2_buffer1 + eps).sqrt() + factor1 = factor1 / factor1.mean(dim=2, keepdim=True).sqrt() + row_col_scale = 1. / (factor0 * factor1) x3 = x3 * row_col_scale #note, we are before computing the cubed part. From 0cf08808061f2f36737d82d7c80c163e4ea29a3a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 12:27:45 +0800 Subject: [PATCH 0938/1191] Decrease excess_scale from 2.5 to 2.0, so stronger x3 decay. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index d792a23ff5..099077fcf7 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -434,7 +434,7 @@ def min_sum_scale(x, y): linear_decay_scale = 0.25 d.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) - excess_scale = 2.5 + excess_scale = 2.0 x3 = d * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times if "delta2_buffer0" not in state: From aa87770de9d5d0a05e20204628c2d29c0557f081 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 13:01:06 +0800 Subject: [PATCH 0939/1191] Add assert statement for stored_delta --- egs/librispeech/ASR/zipformer/optim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 099077fcf7..7975078a48 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -482,6 +482,8 @@ def min_sum_scale(x, y): d = d * row_col_scale # fully-normalized d stored_delta = d.reshape(*stored_delta.shape) # note: permanent buffer is not updated. + + assert torch.all(stored_delta - stored_delta == 0.0) else: stored_delta.mul_(beta1) From 52bf6b280c12619c0a0d7a5d45b8438944a15713 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 14:18:22 +0800 Subject: [PATCH 0940/1191] Make assertion conditional. --- egs/librispeech/ASR/zipformer/optim.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 7975078a48..0d620ffa6c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -446,11 +446,11 @@ def min_sum_scale(x, y): # we'll scale both before and after the cubing. # the lines where we divide by sqrt of the mean are so we don't double # count the scalar component of this. - factor0 = (delta2_buffer0 + eps).sqrt() + factor0 = 1.0 / (delta2_buffer0 + eps).sqrt() factor0 = factor0 / factor0.mean(dim=1, keepdim=True).sqrt() - factor1 = (delta2_buffer1 + eps).sqrt() + factor1 = 1.0 / (delta2_buffer1 + eps).sqrt() factor1 = factor1 / factor1.mean(dim=2, keepdim=True).sqrt() - row_col_scale = 1. / (factor0 * factor1) + row_col_scale = (factor0 * factor1) x3 = x3 * row_col_scale #note, we are before computing the cubed part. @@ -464,6 +464,8 @@ def min_sum_scale(x, y): # we divide x3 by row_col_scale to "un-normalize". d.add_(x3 * alpha / row_col_scale) + d.clamp_(min=-10., max=10.) # avoid divergence + if random.random() < 0.0005: rel_scale = (d ** 2).mean().sqrt() / ((1 - beta1**2)**-0.5) logging.info(f"shape={stored_delta.shape}, rel_scale = {rel_scale.item()}") @@ -483,7 +485,15 @@ def min_sum_scale(x, y): stored_delta = d.reshape(*stored_delta.shape) # note: permanent buffer is not updated. - assert torch.all(stored_delta - stored_delta == 0.0) + + def s(x): + if x.ndim <= 1: + return x.to('cpu') + else: + return (x ** 2).mean(dim=list(range(1, x.ndim))).sqrt().to('cpu') + + if step < 100: + assert torch.all(stored_delta - stored_delta == 0.0), (step, s(stored_delta), s(delta), delta.shape,s(d), s(x3), s(delta2_buffer0), s(delta2_buffer1), s(factor0), s(factor1), s(row_col_scale), s(d2), s(row_col_scale)) else: stored_delta.mul_(beta1) From da5416ab2992270f0899a6f80bdda76bdac1d681 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 14:39:28 +0800 Subject: [PATCH 0941/1191] Revert prev changes to factor creation --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 0d620ffa6c..bde173a729 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -448,9 +448,9 @@ def min_sum_scale(x, y): # count the scalar component of this. factor0 = 1.0 / (delta2_buffer0 + eps).sqrt() factor0 = factor0 / factor0.mean(dim=1, keepdim=True).sqrt() - factor1 = 1.0 / (delta2_buffer1 + eps).sqrt() + factor1 = (delta2_buffer1 + eps).sqrt() factor1 = factor1 / factor1.mean(dim=2, keepdim=True).sqrt() - row_col_scale = (factor0 * factor1) + row_col_scale = 1. / (factor0 * factor1) x3 = x3 * row_col_scale #note, we are before computing the cubed part. From 0b594dfc64516e67f213ae9425e62060eb5b8ed4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 16:11:46 +0800 Subject: [PATCH 0942/1191] Bug fix in how coefficient of x3 is computed. --- egs/librispeech/ASR/zipformer/optim.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index bde173a729..bc7e783eb0 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -434,8 +434,6 @@ def min_sum_scale(x, y): linear_decay_scale = 0.25 d.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) - excess_scale = 2.0 - x3 = d * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times if "delta2_buffer0" not in state: state["delta2_buffer0"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) @@ -446,13 +444,15 @@ def min_sum_scale(x, y): # we'll scale both before and after the cubing. # the lines where we divide by sqrt of the mean are so we don't double # count the scalar component of this. - factor0 = 1.0 / (delta2_buffer0 + eps).sqrt() + factor0 = (delta2_buffer0 + eps).sqrt() factor0 = factor0 / factor0.mean(dim=1, keepdim=True).sqrt() factor1 = (delta2_buffer1 + eps).sqrt() factor1 = factor1 / factor1.mean(dim=2, keepdim=True).sqrt() row_col_scale = 1. / (factor0 * factor1) - x3 = x3 * row_col_scale #note, we are before computing the cubed part. + excess_scale = 2.0 + d_scaled = d * row_col_scale + x3 = d_scaled * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times compute_prod3_inplace(x3) # actually computes 3rd power of its arg divided by max(rows, cols)**2 # the factor of 0.5 says we only want to go, at most, half the way to the point which @@ -460,12 +460,10 @@ def min_sum_scale(x, y): # and having the direction change sign, in a situation where we are not dominated by # the largest singular value; or to prevent the largest singular value from going to # zero if it does dominate. - alpha = (0.5 * min_sum_scale(d, x3)).clamp(min=-1) + alpha = (0.5 * min_sum_scale(d_scaled, x3)).clamp(min=-1) # we divide x3 by row_col_scale to "un-normalize". d.add_(x3 * alpha / row_col_scale) - d.clamp_(min=-10., max=10.) # avoid divergence - if random.random() < 0.0005: rel_scale = (d ** 2).mean().sqrt() / ((1 - beta1**2)**-0.5) logging.info(f"shape={stored_delta.shape}, rel_scale = {rel_scale.item()}") @@ -473,7 +471,7 @@ def min_sum_scale(x, y): if not stored_delta.untyped_storage() is d.untyped_storage(): stored_delta[:] = d.reshape(*stored_delta.shape) - beta = beta1 # use this beta for row/col scales + beta = beta1 # use this beta for row/col scales d = d * row_col_scale # half-normalized d assumed_scale = 0.5 * ((1 - beta1**2)**-0.5) # assumed scalie of d d2 = (d / assumed_scale) ** 2 From 5f28d47dd9a3185aed6fe5503416d416812ec674 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 16:17:27 +0800 Subject: [PATCH 0943/1191] Revert excess_scale to 2.5. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index bc7e783eb0..f4aa7f653b 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -450,7 +450,7 @@ def min_sum_scale(x, y): factor1 = factor1 / factor1.mean(dim=2, keepdim=True).sqrt() row_col_scale = 1. / (factor0 * factor1) - excess_scale = 2.0 + excess_scale = 2.50 d_scaled = d * row_col_scale x3 = d_scaled * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times From e6b6bcd0e4544d01b36a9f123f2cfd50c49b3e4b Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 8 Mar 2026 19:25:16 +0800 Subject: [PATCH 0944/1191] Simplify the causal conv module by removing the within-chunk conv branch --- egs/librispeech/ASR/zipformer/subsampling.py | 71 +++++++++++++++++++- egs/librispeech/ASR/zipformer/zipformer.py | 62 ++++++++--------- 2 files changed, 98 insertions(+), 35 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 0959f417fc..49ab427764 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import warnings from typing import Tuple, Optional @@ -50,7 +51,6 @@ def forward(self, x: Tensor) -> Tensor: return x + noise_scale * torch.randn_like(x) - class ConvNeXt(nn.Module): """ Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf @@ -335,3 +335,72 @@ def get_init_cache( cache = torch.zeros(batch_size, channels, left_pad, freq, device=device) return cache + + +def _test_conv2d_subsampling_streaming(): + logging.info("Testing Conv2dSubsampling streaming equivalence...") + + batch_size = 2 + idim = 80 + odim = 256 + + model = Conv2dSubsampling( + in_channels=idim, + out_channels=odim, + causal=True + ) + + model.eval() + + out_chunk_size = 32 + in_chunk_size = out_chunk_size * 2 + 7 + in_shift = out_chunk_size * 2 + + num_chunks = 10 + + seq_len = num_chunks * in_shift + 7 + + x_full = torch.randn(batch_size, seq_len, idim) + x_lens_full = torch.full((batch_size,), seq_len, dtype=torch.int64) + + with torch.no_grad(): + out_full, out_lens_full = model(x_full, x_lens_full) + + cache = model.get_init_cache(batch_size=batch_size) + + out_chunks = [] + out_offset = 0 + + for i in range(num_chunks): + start = i * in_shift + end = start + in_chunk_size + x_chunk = x_full[:, start:end, :] + x_lens_chunk = torch.full((batch_size,), in_chunk_size, dtype=torch.int64) + + out_chunk, out_lens_chunk, cache = model.streaming_forward( + x_chunk, x_lens_chunk, cache + ) + out_chunks.append(out_chunk) + + out_chunk_len = out_chunk.shape[1] + expected_out = out_full[:, out_offset : out_offset + out_chunk_len, :] + + diff_chunk = torch.max(torch.abs(expected_out - out_chunk)) + logging.info(f"Chunk {i+1} | Input: {x_chunk.shape} -> Output: {out_chunk.shape} | Max diff: {diff_chunk}") + + assert torch.allclose(expected_out, out_chunk, atol=1e-4), f"Chunk {i+1} mismatch! max diff: {diff_chunk}" + out_offset += out_chunk_len + + out_stream_cat = torch.cat(out_chunks, dim=1) + diff_total = torch.max(torch.abs(out_full - out_stream_cat)) + logging.info(f"Total Max Diff between full forward and streaming: {diff_total}") + assert torch.allclose(out_full, out_stream_cat, atol=1e-4), "Total outputs do not match!" + + logging.info("Passed") + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_conv2d_subsampling_streaming() \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 16b733b0a5..4ec9ed9403 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -301,7 +301,7 @@ def _get_attn_mask( chunk_size * left_context_chunks >= (self.conv_params[i] // 2) * self.downsampling_factor[i] for i in range(num_encoders) - ) + ) # TODO: could test remove this else: left_context_chunks = 1000000 @@ -427,9 +427,7 @@ def get_init_caches( key_dim = self.query_head_dim[i] * num_heads value_dim = self.value_head_dim[i] * num_heads downsample_left = self.left_context_frames[0] // ds - - # (self.conv_params[i] + 1) // 2 is the size used in the depthwise causal conv. - conv_left_pad = (self.conv_params[i] + 1) // 2 - 1 + conv_left_pad = self.conv_params[i] - 1 for layer_idx, enc_layer in enumerate(module.layers): cached_key = torch.zeros(downsample_left, batch_size, key_dim, device=device) @@ -1676,7 +1674,15 @@ def __init__( if not causal: self.depthwise_conv = FftConv(bottleneck_dim, kernel_size) else: - self.depthwise_conv = ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + self.depthwise_conv = nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=0, # will pad manually, on one side. + bias=True, + ) + self.left_pad = kernel_size - 1 self.out_proj = ActivationDropoutAndLinear( bottleneck_dim, @@ -1726,7 +1732,10 @@ def forward( # for the causal version, we don't use fft-conv if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x, chunk_size=chunk_size) + x_shape = x.shape + x = torch.nn.functional.pad(x, (self.left_pad, 0)) + x = self.depthwise_conv(x) + assert x.shape == x_shape, (x.shape, x_shape) x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) else: x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) @@ -1768,7 +1777,16 @@ def streaming_forward( x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x_shape = x.shape + assert cache.shape[-1] == self.left_pad, (cache.shape[-1], self.left_pad) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[..., -self.left_pad:] + + x = self.depthwise_conv(x) + assert x.shape == x_shape, (x.shape, x_shape) + x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) x = x * y @@ -1776,30 +1794,6 @@ def streaming_forward( return x, cache - def repeat_in_padding(self, x, mask): - # repeats elements of x in the padding region, circularly as much as possible; - # the discontinuity between the ones that circularly repeat from the end and - # those that circularly repeat from the beginning is in the middle of the padding - # region. - - # x: (seq_len, batch_size, num_channels) - (batch_size, seq_len) = mask.shape - - seq_lengths = (~mask).to(torch.int64).sum(dim=1, keepdim=True) # (batch_size, 1) - pad_len = seq_len - seq_lengths - arange = torch.arange(seq_len, device=mask.device) - - # "mid" gives the index of the midpoint of the padding region after each sequence. - mid = (seq_lengths + seq_len) // 2 # mid: (batch_size, 1) - - src_index = torch.where(arange >= mid, arange - pad_len, arange) % seq_lengths - # src_index: (batch_size, seq_len) - - src_index = src_index.t().unsqueeze(-1).expand_as(x) - # src_index: (seq_len, batch_size, num_channels) - x = torch.gather(x, dim=0, index=src_index) - return x - class ScalarMultiply(nn.Module): def __init__(self, scale: float): @@ -1849,7 +1843,7 @@ def _test_zipformer_streaming(): input_dim = 50 batch_size = 2 chunk_size = 32 - num_chunks = 3 + num_chunks = 10 tail_chunk_size = 8 seq_len = chunk_size * num_chunks + tail_chunk_size left_context_frames = 128 @@ -1858,7 +1852,7 @@ def _test_zipformer_streaming(): input_dim=input_dim, encoder_dim=(64, 96, 128, 96), num_heads=(4, 4, 4, 4), - conv_params=(7, 7, 7, 7), + conv_params=(31, 31, 15, 31), downsampling_factor=(1, 2, 4, 2), causal=True, chunk_size=(chunk_size,), @@ -1946,7 +1940,7 @@ def _test_zipformer_streaming(): assert torch.allclose(out_full, out_stream_cat, atol=2e-5), f"Outputs do not match! Max diff: {diff}" - logging.info("Zipformer streaming_forward test passed") + logging.info("Passed") if __name__ == "__main__": From 2890cd66e68fd057e26668d1220b8eba0ca29ca5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 21:23:36 +0800 Subject: [PATCH 0945/1191] Remove factor of 0.5 in assumed_scale. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index f4aa7f653b..2688130db0 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -473,7 +473,7 @@ def min_sum_scale(x, y): beta = beta1 # use this beta for row/col scales d = d * row_col_scale # half-normalized d - assumed_scale = 0.5 * ((1 - beta1**2)**-0.5) # assumed scalie of d + assumed_scale = ((1 - beta1**2)**-0.5) # assumed scale of d d2 = (d / assumed_scale) ** 2 if random.random() < 0.001: logging.info(f"shape={stored_delta.shape}, mean of normalized d2 is {d2.mean().item()}") From a3feaf487ebf0f03470b6a1ebf5bf018623ab65c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 8 Mar 2026 22:25:12 +0800 Subject: [PATCH 0946/1191] Introduce delta_scale_buffer to ensure a constant scale of final delta; set excess_scale = 5.0 --- egs/librispeech/ASR/zapformer/train.py | 4 +-- egs/librispeech/ASR/zipformer/optim.py | 34 +++++++++++++++++--------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index f59e1160de..9f7ed41a6a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -425,7 +425,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.001, help="The base learning rate." + "--base-lr", type=float, default=0.00065, help="The base learning rate." ) parser.add_argument( @@ -1362,7 +1362,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - wd=12.5, + wd=25, scale_limits=(1.0, 4.0), ) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 2688130db0..90af94d2a7 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -426,31 +426,37 @@ def min_sum_scale(x, y): # alpha = xy / yy return -xy / (yy + eps) - stored_delta.add_(delta) if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) - # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. - # this should be configurable. - linear_decay_scale = 0.25 - - d.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) if "delta2_buffer0" not in state: state["delta2_buffer0"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) state["delta2_buffer1"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) + state["delta_scale_buffer"] = torch.ones(d.shape[0], 1, 1, device=d.device, dtype=d.dtype) + delta2_buffer0 = state["delta2_buffer0"] delta2_buffer1 = state["delta2_buffer1"] + delta_scale_buffer = state["delta_scale_buffer"] + d.add_(delta.reshape(*d.shape) * delta_scale_buffer) + + + # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. + # this should be configurable. + linear_decay_scale = 0.25 + + d.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) + # we'll scale both before and after the cubing. # the lines where we divide by sqrt of the mean are so we don't double # count the scalar component of this. factor0 = (delta2_buffer0 + eps).sqrt() - factor0 = factor0 / factor0.mean(dim=1, keepdim=True).sqrt() + factor0_mean = factor0.mean(dim=1, keepdim=True) + factor0 = factor0 / factor0_mean factor1 = (delta2_buffer1 + eps).sqrt() - factor1 = factor1 / factor1.mean(dim=2, keepdim=True).sqrt() row_col_scale = 1. / (factor0 * factor1) - excess_scale = 2.50 + excess_scale = 5.0 d_scaled = d * row_col_scale x3 = d_scaled * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times @@ -475,6 +481,9 @@ def min_sum_scale(x, y): d = d * row_col_scale # half-normalized d assumed_scale = ((1 - beta1**2)**-0.5) # assumed scale of d d2 = (d / assumed_scale) ** 2 + + delta_scale_buffer.add_((1 - d2.mean(dim=(1, 2), keepdim=True)).sign(), alpha=0.01) # infinite gain to make factor0_mean equal to 1 + if random.random() < 0.001: logging.info(f"shape={stored_delta.shape}, mean of normalized d2 is {d2.mean().item()}") delta2_buffer0.mul_(beta).add_(d2.mean(dim=2, keepdim=True), alpha=(1 - beta)) @@ -493,6 +502,7 @@ def s(x): if step < 100: assert torch.all(stored_delta - stored_delta == 0.0), (step, s(stored_delta), s(delta), delta.shape,s(d), s(x3), s(delta2_buffer0), s(delta2_buffer1), s(factor0), s(factor1), s(row_col_scale), s(d2), s(row_col_scale)) else: + stored_delta.add_(delta) stored_delta.mul_(beta1) @@ -1112,11 +1122,11 @@ def _test_transformed_adam(hidden_dim: int): for _ in range(20) ] - lr = 0.001 + lr = 0.0006 if test == 0: - optim = TransformedAdam(m.named_parameters(), lr=lr, wd=12, eps=1.0e-20, beta1=0.99) + optim = TransformedAdam(m.named_parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.99) elif test == 1: - optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=12, eps=1.0e-20, beta1=0.99) + optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.99) num_epochs = 180 From 176ac44ef9e037d5ae58715cdb8d6c2262f94841 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 11:15:54 +0800 Subject: [PATCH 0947/1191] Decrease excess_scale from 5.0 to 3.0. --- egs/librispeech/ASR/zipformer/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 90af94d2a7..8c96ec7f38 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -456,7 +456,7 @@ def min_sum_scale(x, y): factor1 = (delta2_buffer1 + eps).sqrt() row_col_scale = 1. / (factor0 * factor1) - excess_scale = 5.0 + excess_scale = 3.0 d_scaled = d * row_col_scale x3 = d_scaled * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times From aabd4f3e4b11af89ec14dc650ef216a9329519c6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 11:16:46 +0800 Subject: [PATCH 0948/1191] Decrease logging probs --- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8c96ec7f38..4ab74dfa06 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -470,7 +470,7 @@ def min_sum_scale(x, y): # we divide x3 by row_col_scale to "un-normalize". d.add_(x3 * alpha / row_col_scale) - if random.random() < 0.0005: + if random.random() < 0.0001: rel_scale = (d ** 2).mean().sqrt() / ((1 - beta1**2)**-0.5) logging.info(f"shape={stored_delta.shape}, rel_scale = {rel_scale.item()}") @@ -484,7 +484,7 @@ def min_sum_scale(x, y): delta_scale_buffer.add_((1 - d2.mean(dim=(1, 2), keepdim=True)).sign(), alpha=0.01) # infinite gain to make factor0_mean equal to 1 - if random.random() < 0.001: + if random.random() < 0.0001: logging.info(f"shape={stored_delta.shape}, mean of normalized d2 is {d2.mean().item()}") delta2_buffer0.mul_(beta).add_(d2.mean(dim=2, keepdim=True), alpha=(1 - beta)) delta2_buffer1.mul_(beta).add_(d2.mean(dim=1, keepdim=True), alpha=(1 - beta)) From 34b67a35d4f6f5cb942b7a210fc3680e378725fb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 12:02:23 +0800 Subject: [PATCH 0949/1191] Change LR scheduler to cosine scheduler with dual-purpose min_factor = 0.1 (linearly applied but makes final min_factor of lr schedule constant.) --- egs/librispeech/ASR/zapformer/combined_scheduler.py | 9 +++++++-- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 58e120728e..0eb5744fac 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -120,15 +120,20 @@ def print_lr(self, is_verbose, group, lr): class CosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, - min_factor: float = 0.2, + min_factor: float = 0.1, **kwargs): super().__init__(*args, **kwargs) + # min_factor has two roles: it acts as a minimum relative learning rate + # (linearly applied, not with max); and it makes the final learning rate + # constant for the final (min_factor) of the schedule, compressing the + # cosine decay into the first (1 - min_factor) of the schedule. self.min_factor = min_factor def get_lr(self): progress = self.get_progress() + progress = min(1.0, progress / (1.0 - min_factor)) # clamp progress at 1.0 for final min_factor of schedule. factor = 0.5 * (1.0 + math.cos(math.pi * progress)) - factor = self.min_factor + (1. - self.min_factor) * factor + factor = self.min_factor + (1. - self.min_factor) * factor # apply min_factor linearly. return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 9f7ed41a6a..0106854086 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -78,7 +78,7 @@ from optim import TransformedAdam from combined_scheduler import CombinedLRScheduler, CosineLRScheduler try: - from combined_scheduler import LinearLRScheduler + from combined_scheduler import CosineLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR @@ -1373,7 +1373,7 @@ def lr_lambda(current_step): progress = current_step / total_steps return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - scheduler = LinearLRScheduler(optimizer, + scheduler = CosineLRScheduler(optimizer, batches_per_epoch=params.batches_per_epoch, num_epochs=params.num_epochs) From 549e73ab0009430bf9ff49118bdef9a5c6e1c98c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 13:37:16 +0800 Subject: [PATCH 0950/1191] Bug fix --- egs/librispeech/ASR/zapformer/combined_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 0eb5744fac..54fb799643 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -131,7 +131,7 @@ def __init__(self, def get_lr(self): progress = self.get_progress() - progress = min(1.0, progress / (1.0 - min_factor)) # clamp progress at 1.0 for final min_factor of schedule. + progress = min(1.0, progress / (1.0 - self.min_factor)) # clamp progress at 1.0 for final min_factor of schedule. factor = 0.5 * (1.0 + math.cos(math.pi * progress)) factor = self.min_factor + (1. - self.min_factor) * factor # apply min_factor linearly. return [x * factor for x in self.base_lrs] From f1629006d019ab2d5b25d74899f1b08796e2fe3e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 15:34:30 +0800 Subject: [PATCH 0951/1191] Code improvement (making excess_scale,linear_decay_scale configurable); decrease linear_decay_scale from .25 to .2 --- egs/librispeech/ASR/zipformer/optim.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 4ab74dfa06..b82287f885 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -143,7 +143,6 @@ def base_step(group, state, grad): exp_avg_sq = torch.zeros(*stats_shape, device=grad.device, dtype=torch.float) state["exp_avg_sq"] = exp_avg_sq - mean_dims = list(range(1, grad.ndim)) grad2 = (grad ** 2) if len(mean_dims) > 0: @@ -394,7 +393,7 @@ def prod(l): def momentum_step(group, state, grad): - delta = base_step(group, state, grad) + delta = base_step(group, state, grad) # base_step just normalizes overall scale of tensor with a scalar estimate of its rms. # delta is the normalized gradient; the rms of delta should be around 1. lr = group["lr"] @@ -402,6 +401,9 @@ def momentum_step(group, state, grad): step = state["step"] beta1 = min(group["beta1"], 1. - 1. / (10. + 0.2 * step)) direct = group["direct"] + linear_decay_scale = group["linear_decay_scale"] + excess_scale = group["excess_scale"] + min_scale, max_scale = group["scale_limits"] try: @@ -440,10 +442,6 @@ def min_sum_scale(x, y): d.add_(delta.reshape(*d.shape) * delta_scale_buffer) - # decay by one quarter of the beta1-determined decay rate, leaving the rest to the x^3 decay. - # this should be configurable. - linear_decay_scale = 0.25 - d.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) @@ -456,7 +454,6 @@ def min_sum_scale(x, y): factor1 = (delta2_buffer1 + eps).sqrt() row_col_scale = 1. / (factor0 * factor1) - excess_scale = 3.0 d_scaled = d * row_col_scale x3 = d_scaled * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times @@ -609,6 +606,8 @@ def __init__( lr=1e-03, beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) + linear_decay_scale=0.2, + excess_scale=3.0, beta2=0.98, wd=10, eps=1.0e-08, @@ -619,6 +618,8 @@ def __init__( lr=lr, beta1=beta1, direct=direct, + linear_decay_scale=linear_decay_scale, + excess_scale=excess_scale, beta2=beta2, eps=eps, wd=wd, @@ -1017,6 +1018,8 @@ def __init__( lr=1e-03, beta1=0.995, direct=0.05, # scale on bypass of momentum (beta1) + linear_decay_scale=0.2, + excess_scale=3.0, beta2=0.98, wd=10, eps=1.0e-08, @@ -1026,6 +1029,8 @@ def __init__( lr=lr, beta1=beta1, direct=direct, + linear_decay_scale=linear_decay_scale, + excess_scale=excess_scale, beta2=beta2, eps=eps, wd=wd, From 00de32517194669d53c158a35d103c1d701c3578 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 18:52:59 +0800 Subject: [PATCH 0952/1191] Reformulate optim.py in a cleaner way; normalize bypass stats. --- egs/librispeech/ASR/zipformer/optim.py | 361 ++++++++----------------- 1 file changed, 118 insertions(+), 243 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b82287f885..b77bc03b0d 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -126,149 +126,14 @@ def batched_params(self, param_group, group_params_names): - -def base_step(group, state, grad): - # computes basic Adam normalized-grad using beta2 (dividing by gradient stddev) only. no momentum yet. - # this normalied-grad is normalized only at the whole tensor level for now. - - beta2 = group["beta2"] - eps = group["eps"] - # p shape: (batch_size,) or (batch_size, 1, [1,..]) - try: - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) or (batch_size, 1, [1,..]) - except KeyError: - assert state["step"] < 2 - batch_size = grad.shape[0] - stats_shape = [batch_size] + [1] * (len(grad.shape) - 1) - exp_avg_sq = torch.zeros(*stats_shape, device=grad.device, dtype=torch.float) - state["exp_avg_sq"] = exp_avg_sq - - mean_dims = list(range(1, grad.ndim)) - grad2 = (grad ** 2) - if len(mean_dims) > 0: - grad2 = grad2.mean(dim=mean_dims, keepdim=True) - exp_avg_sq.mul_(beta2).add_(grad2, alpha=1 - beta2) - - # bias_correction2 is like in Adam. - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - denom = exp_avg_sq.sqrt().add_(eps) - - return grad / denom - - -def compute_prod5_inplace(x): # replaces x with x^3 / max(rows, cols), x is interpreted as a batch of matrices. - assert x.ndim >= 3 - - - if x.ndim > 3: - # each tensor in the batch has more than two dimensions. - # reshape to be like a batch of matrices. - # note: x.shape[0] is batch dimension. - if x.shape[1] > x.shape[-1]: - xr = x.reshape(x.shape[0], x.shape[1], -1) - else: - xr = x.reshape(x.shape[0], -1, x.shape[-1]) - compute_prod5_inplace(xr) - if not xr.untyped_storage() is x.untyped_storage(): - x[:] = xr.reshape(*x.shape) - return - if x.shape[1] > x.shape[2]: - xr = x.permute(0, 2, 1) - compute_prod5_inplace(xr) - if not xr.untyped_storage() is x.untyped_storage(): - x[:] = xr.permute(0, 2, 1) - return - - # avoid matrix multiplies by any dimensions that are too large. - max_dim = 1024 - if x.shape[1] > max_dim: - n = x.shape[1] - for divisor in range(2, 100): - if n % divisor == 0 and n // divisor <= max_dim: - xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) - compute_prod5_inplace(xr) - if not xr.untyped_storage() is x.untyped_storage(): - x[:] = xr.reshape(*x.shape) - return - # if no divisor worked, just continue. - - (batch_size, rows, cols) = x.shape # and rows <= cols - - x2 = torch.matmul(x, x.permute(0, 2, 1)) / max(rows, cols) - x4 = torch.matmul(x2, x2) - x5 = torch.matmul(x4, x) - - x[:] = x5 - - - - -def compute_prod5(x): - # computes matrix-matrix-matrix-matrix-matrix product of batch of matrices x, with reshaping if necessary; - # first divides x by max(num_rows, num_cols)^2 so its a kind of normalized 5th-product. - x = x.clone() - compute_prod5_inplace(x) - return x - - -def compute_prod3_inplace(x): # replaces x with x^3 / max(rows, cols), x is interpreted as a batch of matrices. - assert x.ndim >= 3 - - - if x.ndim > 3: - # each tensor in the batch has more than two dimensions. - # reshape to be like a batch of matrices. - # note: x.shape[0] is batch dimension. - if x.shape[1] > x.shape[-1]: - xr = x.reshape(x.shape[0], x.shape[1], -1) - else: - xr = x.reshape(x.shape[0], -1, x.shape[-1]) - compute_prod3_inplace(xr) - if not xr.untyped_storage() is x.untyped_storage(): - x[:] = xr.reshape(*x.shape) - return - if x.shape[1] > x.shape[2]: - xr = x.permute(0, 2, 1) - compute_prod3_inplace(xr) - if not xr.untyped_storage() is x.untyped_storage(): - x[:] = xr.permute(0, 2, 1) - return - - # avoid matrix multiplies by any dimensions that are too large. - max_dim = 1024 - if x.shape[1] > max_dim: - n = x.shape[1] - for divisor in range(2, 100): - if n % divisor == 0 and n // divisor <= max_dim: - xr = x.reshape(x.shape[0] * divisor, n // divisor, x.shape[2]) - compute_prod3_inplace(xr) - if not xr.untyped_storage() is x.untyped_storage(): - x[:] = xr.reshape(*x.shape) - return - # if no divisor worked, just continue. - - (batch_size, rows, cols) = x.shape # and rows <= cols - - x2 = torch.matmul(x, x.permute(0, 2, 1)) / max(rows, cols) - x3 = torch.matmul(x2, x) - - x[:] = x3 - - - - def compute_prod3(x): - # computes matrix-matrix-matrix-matrix-matrix product of batch of matrices x, with reshaping if necessary; - # first divides x by max(num_rows, num_cols)^2 so its a kind of normalized 3rdproduct. - x = x.clone() - compute_prod3_inplace(x) - return x - - + assert x.ndim == 3 + if x.shape[1] <= x.shape[2]: + x2 = torch.matmul(x, x.permute(0, 2, 1)) + return torch.matmul(x2, x) + else: + x2 = torch.matmul(x.permute(0, 2, 1), x) + return torch.matmul(x, x2) def scale_by(x, beta1): @@ -392,17 +257,16 @@ def prod(l): return batch_size, prod(shape[:i]), prod(shape[i:]) -def momentum_step(group, state, grad): - delta = base_step(group, state, grad) # base_step just normalizes overall scale of tensor with a scalar estimate of its rms. - # delta is the normalized gradient; the rms of delta should be around 1. +def cubic_decay_step(group, state, grad): + delta = grad lr = group["lr"] eps = group["eps"] step = state["step"] beta1 = min(group["beta1"], 1. - 1. / (10. + 0.2 * step)) + beta2 = group["beta2"] direct = group["direct"] - linear_decay_scale = group["linear_decay_scale"] - excess_scale = group["excess_scale"] + cubic_decay_scale = group["cubic_decay_scale"] min_scale, max_scale = group["scale_limits"] @@ -414,8 +278,6 @@ def momentum_step(group, state, grad): stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) state["delta"] = stored_delta - - def min_sum_scale(x, y): # returns the scale alpha such that (x + alpha y) is minimized. x and y have # the same shape and the shape of alpha is (x.shape[0], 1, 1, ...). @@ -428,96 +290,102 @@ def min_sum_scale(x, y): # alpha = xy / yy return -xy / (yy + eps) - if delta.ndim >= 3 and delta.numel() != delta.shape[0] * max(delta.shape[1:]): - d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) + d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) + assert d.untyped_storage() is stored_delta.untyped_storage() + (batch_size, rows, cols) = d.shape - if "delta2_buffer0" not in state: - state["delta2_buffer0"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) - state["delta2_buffer1"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) - state["delta_scale_buffer"] = torch.ones(d.shape[0], 1, 1, device=d.device, dtype=d.dtype) + if "row_stats" not in state: + state["row_stats"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) + state["direct_row_stats"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) + state["col_stats"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) + state["direct_col_stats"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) - delta2_buffer0 = state["delta2_buffer0"] - delta2_buffer1 = state["delta2_buffer1"] - delta_scale_buffer = state["delta_scale_buffer"] - d.add_(delta.reshape(*d.shape) * delta_scale_buffer) + row_stats = state["row_stats"] + col_stats = state["col_stats"] + direct_row_stats = state["direct_row_stats"] + direct_col_stats = state["direct_col_stats"] + delta = delta.reshape(*d.shape) - d.mul_(linear_decay_scale * beta1 + (1 - linear_decay_scale)) + d.add_(delta, alpha=(1 - beta1)) + d.mul_(beta1) + d2 = d ** 2 - # we'll scale both before and after the cubing. - # the lines where we divide by sqrt of the mean are so we don't double - # count the scalar component of this. - factor0 = (delta2_buffer0 + eps).sqrt() - factor0_mean = factor0.mean(dim=1, keepdim=True) - factor0 = factor0 / factor0_mean - factor1 = (delta2_buffer1 + eps).sqrt() - row_col_scale = 1. / (factor0 * factor1) + # we'll scale both before and after the cubing. + # the lines where we divide by sqrt of the mean are so we don't double + # count the scalar component of this. + row_scale = (row_stats + eps).sqrt() + col_scale = (col_stats + eps).sqrt() + row_col_scale = row_scale * col_scale - d_scaled = d * row_col_scale - x3 = d_scaled * (((1 - beta1**2)**0.5) / excess_scale) # normalized-scale version of stored_delta, times + d_norm1 = d / row_col_scale # this is the first of two steps of normalizing by these stats. - compute_prod3_inplace(x3) # actually computes 3rd power of its arg divided by max(rows, cols)**2 - # the factor of 0.5 says we only want to go, at most, half the way to the point which - # would give us the minimum 'x'. this is to prevent the largest eigs overshooting - # and having the direction change sign, in a situation where we are not dominated by - # the largest singular value; or to prevent the largest singular value from going to - # zero if it does dominate. - alpha = (0.5 * min_sum_scale(d_scaled, x3)).clamp(min=-1) - # we divide x3 by row_col_scale to "un-normalize". - d.add_(x3 * alpha / row_col_scale) + d_norm1_meansq = (d_norm1 ** 2).mean(dim=(1, 2), keepdim=True) + eps - if random.random() < 0.0001: - rel_scale = (d ** 2).mean().sqrt() / ((1 - beta1**2)**-0.5) - logging.info(f"shape={stored_delta.shape}, rel_scale = {rel_scale.item()}") + d_norm1_scaled = d_norm1 * (d_norm1_meansq * max(rows, cols)) ** (-1/3) + prod3 = compute_prod3(d_norm1_scaled) - if not stored_delta.untyped_storage() is d.untyped_storage(): - stored_delta[:] = d.reshape(*stored_delta.shape) - beta = beta1 # use this beta for row/col scales - d = d * row_col_scale # half-normalized d - assumed_scale = ((1 - beta1**2)**-0.5) # assumed scale of d - d2 = (d / assumed_scale) ** 2 + alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_scale) + # we multiply prod3 by row_col_scale to "un-normalize". + # In the normal case where we're not limited by stability-of-update-concerns, + # the next line of code is equivalent to: + # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_scale) + d.add_((prod3 * row_col_scale) * alpha) - delta_scale_buffer.add_((1 - d2.mean(dim=(1, 2), keepdim=True)).sign(), alpha=0.01) # infinite gain to make factor0_mean equal to 1 + d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. - if random.random() < 0.0001: - logging.info(f"shape={stored_delta.shape}, mean of normalized d2 is {d2.mean().item()}") - delta2_buffer0.mul_(beta).add_(d2.mean(dim=2, keepdim=True), alpha=(1 - beta)) - delta2_buffer1.mul_(beta).add_(d2.mean(dim=1, keepdim=True), alpha=(1 - beta)) - d = d * row_col_scale # fully-normalized d + d_norm1_sq = d_norm1 ** 2 - stored_delta = d.reshape(*stored_delta.shape) # note: permanent buffer is not updated. + # first update row_stats. + row_stats.mul_(beta2).add_((d_norm1 ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) + # d_norm1b means we've doing the second normalization but only by rows so far. + d_norm1b = d_norm1 / (row_stats + eps).sqrt() - def s(x): - if x.ndim <= 1: - return x.to('cpu') - else: - return (x ** 2).mean(dim=list(range(1, x.ndim))).sqrt().to('cpu') + col_stats.mul_(beta2).add_((d_norm1b ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + + d_norm2 = d_norm1b / (col_stats + eps).sqrt() + + # do "immediate" normalization of total norm to make the overall scale of the update what + # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) # assumed scale of d if stats were i.i.d. + + d_norm3 = d_norm2 * (assumed_scale / ((d_norm2 ** 2).mean(dim=(1, 2), keepdim=True) + eps).sqrt()) + + moving_update = d_norm3 + + def s(x): + if x.ndim <= 1: + return x.to('cpu') + else: + return (x ** 2).mean(dim=list(range(1, x.ndim))).sqrt().to('cpu') + + #if step < 100: + # assert torch.all(stored_delta - stored_delta == 0.0), (step, s(stored_delta), s(delta), delta.shape,s(d), s(x3), s(delta2_buffer0), s(delta2_buffer1), s(factor0), s(factor1), s(row_col_scale), s(d2), s(row_col_scale)) - if step < 100: - assert torch.all(stored_delta - stored_delta == 0.0), (step, s(stored_delta), s(delta), delta.shape,s(d), s(x3), s(delta2_buffer0), s(delta2_buffer1), s(factor0), s(factor1), s(row_col_scale), s(d2), s(row_col_scale)) - else: - stored_delta.add_(delta) - stored_delta.mul_(beta1) + # row/col normalization of direct/bypass gradient "delta". + direct_row_stats.mul_(beta2).add_((delta ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) + delta = delta / (direct_row_stats + eps).sqrt() + direct_col_stats.mul_(beta2).add_((delta ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + delta = delta / (direct_col_stats + eps).sqrt() - ans = (((1-direct) * (1-beta1)) * stored_delta) + (direct * delta) - # OK, now divide ans by its rms so it has unit rms - norm_ans = False - if norm_ans: - dims = list(range(1, ans.ndim)) - ans = ans / ((ans ** 2).mean(dim=dims, keepdim=True) + eps).sqrt() - return -lr * ans + ans = (-lr * (1-direct)) * moving_update + (-lr * direct) * delta + return ans.reshape(*grad.shape) def scaling_step(group, param, state, grad): - delta = momentum_step(group, state, grad) - # delta is the normalized gradient; the rms of delta should be around 1. lr = group["lr"] wd = group["wd"] + if grad.ndim >= 3 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): + delta = cubic_decay_step(group, state, grad) + else: + # biases and similar-shaped tensors + delta = adam_step(group, state, grad) + try: scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] @@ -547,23 +415,34 @@ def scaling_step(group, param, state, grad): return param * delta_scale + scale * delta -def basic_momentum_step(group, state, grad, lr, beta): - delta = base_step(group, state, grad) - +def adam_step(group, state, grad): + lr = group["lr"] step = state["step"] + eps = group["eps"] + # just hardcode these. we only use this code for biases and scalars. + beta1 = 0.98 + beta2 = 0.98 + try: - stored_delta = state["delta"] + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] except KeyError as e: assert step < 2 - # scalar. use conventional momentum. - stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) - state["delta"] = stored_delta + exp_avg = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq + + exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + bias_correction2 = 1 - beta2 ** (step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = (exp_avg_sq + eps).sqrt() - stored_delta.add_(delta) - stored_delta.mul_(beta) + return -lr * (exp_avg / denom) - delta = (-lr * (1 - beta)) * stored_delta - return delta @@ -604,13 +483,12 @@ def __init__( self, params, lr=1e-03, - beta1=0.995, - direct=0.05, # scale on bypass of momentum (beta1) - linear_decay_scale=0.2, - excess_scale=3.0, + beta1=0.998, + direct=0.1, # scale on bypass of momentum (beta1) + cubic_decay_scale=0.05, beta2=0.98, wd=10, - eps=1.0e-08, + eps=1.0e-16, scale_limits=(0.5, 2.0), ): @@ -618,8 +496,7 @@ def __init__( lr=lr, beta1=beta1, direct=direct, - linear_decay_scale=linear_decay_scale, - excess_scale=excess_scale, + cubic_decay_scale=cubic_decay_scale, beta2=beta2, eps=eps, wd=wd, @@ -771,7 +648,7 @@ def step(self, closure=None): cur_step = 0 if p.numel() == p.shape[0]: - p += basic_momentum_step(group, state, grad, group["lr"], group["beta1"]) + p += adam_step(group, state, grad) else: p += scaling_step(group, p.detach(), state, grad) @@ -1016,21 +893,19 @@ def __init__( self, params, lr=1e-03, - beta1=0.995, - direct=0.05, # scale on bypass of momentum (beta1) - linear_decay_scale=0.2, - excess_scale=3.0, + beta1=0.998, + direct=0.1, # scale on bypass of momentum (beta1) + cubic_decay_scale=0.05, beta2=0.98, wd=10, - eps=1.0e-08, + eps=1.0e-16, scale_limits=(0.5, 2.0), ): defaults = dict( lr=lr, beta1=beta1, direct=direct, - linear_decay_scale=linear_decay_scale, - excess_scale=excess_scale, + cubic_decay_scale=cubic_decay_scale, beta2=beta2, eps=eps, wd=wd, @@ -1073,7 +948,7 @@ def u(x): return x.unsqueeze(0) if p.numel() == 1: - p += basic_momentum_step(group, state, grad, group["lr"], group["beta1"]) + p += adam_step(group, state, grad) else: p += scaling_step(group, u(p.detach()), state, u(grad))[0] @@ -1129,9 +1004,9 @@ def _test_transformed_adam(hidden_dim: int): lr = 0.0006 if test == 0: - optim = TransformedAdam(m.named_parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.99) + optim = TransformedAdam(m.named_parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.998, direct=0.05, cubic_decay_scale=0.005) elif test == 1: - optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.99) + optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.998, direct=0.05, cubic_decay_scale=0.005) num_epochs = 180 From 92875aace3460bbb117b88a5d2be121653448b5f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 18:58:52 +0800 Subject: [PATCH 0953/1191] Change some defaults in optim.py --- egs/librispeech/ASR/zipformer/optim.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index b77bc03b0d..3b56e34985 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -484,10 +484,10 @@ def __init__( params, lr=1e-03, beta1=0.998, - direct=0.1, # scale on bypass of momentum (beta1) - cubic_decay_scale=0.05, + direct=0.05, # scale on bypass of momentum (beta1) + cubic_decay_scale=0.005, beta2=0.98, - wd=10, + wd=25, eps=1.0e-16, scale_limits=(0.5, 2.0), ): @@ -894,10 +894,10 @@ def __init__( params, lr=1e-03, beta1=0.998, - direct=0.1, # scale on bypass of momentum (beta1) - cubic_decay_scale=0.05, + direct=0.05, # scale on bypass of momentum (beta1) + cubic_decay_scale=0.005, beta2=0.98, - wd=10, + wd=25, eps=1.0e-16, scale_limits=(0.5, 2.0), ): @@ -1004,9 +1004,9 @@ def _test_transformed_adam(hidden_dim: int): lr = 0.0006 if test == 0: - optim = TransformedAdam(m.named_parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.998, direct=0.05, cubic_decay_scale=0.005) + optim = TransformedAdam(m.named_parameters(), lr=lr, wd=24) elif test == 1: - optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=24, eps=1.0e-20, beta1=0.998, direct=0.05, cubic_decay_scale=0.005) + optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=24) num_epochs = 180 From 22e10c1cde59584163e07490f0b764dd16018113 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 9 Mar 2026 19:35:54 +0800 Subject: [PATCH 0954/1191] Express cubic_decay_scale relative to beta1. --- egs/librispeech/ASR/zipformer/optim.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3b56e34985..44f6d0e5bd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -327,7 +327,7 @@ def min_sum_scale(x, y): prod3 = compute_prod3(d_norm1_scaled) - alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_scale) + alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_scale*(1-beta1)) # we multiply prod3 by row_col_scale to "un-normalize". # In the normal case where we're not limited by stability-of-update-concerns, # the next line of code is equivalent to: @@ -485,7 +485,7 @@ def __init__( lr=1e-03, beta1=0.998, direct=0.05, # scale on bypass of momentum (beta1) - cubic_decay_scale=0.005, + cubic_decay_scale=2.5, beta2=0.98, wd=25, eps=1.0e-16, @@ -895,7 +895,7 @@ def __init__( lr=1e-03, beta1=0.998, direct=0.05, # scale on bypass of momentum (beta1) - cubic_decay_scale=0.005, + cubic_decay_scale=2.5, beta2=0.98, wd=25, eps=1.0e-16, @@ -1010,18 +1010,11 @@ def _test_transformed_adam(hidden_dim: int): num_epochs = 180 - warmup_steps = 0 - # hardcode batches per epoch for now. total_steps = num_epochs - warmup_start = 0.5 def lr_lambda(current_step): - if current_step < warmup_steps: - # Linear warm-up - return warmup_start + (1.0 - warmup_start) * current_step / warmup_steps - else: - # Cosine annealing - progress = (current_step - warmup_steps) / (total_steps - warmup_steps) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + # Cosine annealing + progress = current_step / total_steps + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) scheduler = LambdaLR(optim, lr_lambda) From 725226a7972dbfe99dc06ee9996c42f7f537d13f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Mar 2026 10:38:59 +0800 Subject: [PATCH 0955/1191] Refactor the update to replace cubic_decay_scale with cubic_decay_proportion; set beta1=0.999, direct=0.0 (must tune) --- egs/librispeech/ASR/zipformer/optim.py | 82 +++++++++++++++++--------- 1 file changed, 54 insertions(+), 28 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 44f6d0e5bd..3743c36304 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -127,14 +127,24 @@ def batched_params(self, param_group, group_params_names): def compute_prod3(x): - assert x.ndim == 3 - if x.shape[1] <= x.shape[2]: - x2 = torch.matmul(x, x.permute(0, 2, 1)) + assert x.ndim >= 2 + if x.shape[-2] <= x.shape[-1]: + x2 = torch.matmul(x, x.transpose(-2, -1)) return torch.matmul(x2, x) else: - x2 = torch.matmul(x.permute(0, 2, 1), x) + x2 = torch.matmul(x.transpose(-2, -1), x) return torch.matmul(x, x2) +def compute_scaled_prod3(x): + # computes 3-way matrix power x^3 (x is treated as a batch of matrices) with a scaling such that (for each + # matrix in the batch) if all the singular values of the matrix are the same, the result will be identical to the input. + + rows, cols = x.shape[-2], x.shape[-1] + + eps = 1.0e-40 + x_meansq = (x ** 2).mean(dim=(-2, -1), keepdim=True) + eps + x = x * (x_meansq * max(rows, cols)) ** (-1/3) + return compute_prod3(x) def scale_by(x, beta1): # This is similar in efffect @@ -266,7 +276,8 @@ def cubic_decay_step(group, state, grad): beta1 = min(group["beta1"], 1. - 1. / (10. + 0.2 * step)) beta2 = group["beta2"] direct = group["direct"] - cubic_decay_scale = group["cubic_decay_scale"] + cubic_decay_proportion = group["cubic_decay_proportion"] + linear_decay_proportion = 1. - cubic_decay_proportion min_scale, max_scale = group["scale_limits"] @@ -307,8 +318,8 @@ def min_sum_scale(x, y): delta = delta.reshape(*d.shape) - d.add_(delta, alpha=(1 - beta1)) - d.mul_(beta1) + d.add_(delta) # the scale used here doesn't matter as it all gets normalized. + d.mul_(1 - (linear_decay_proportion * (1 - beta1))) d2 = d ** 2 @@ -321,17 +332,13 @@ def min_sum_scale(x, y): d_norm1 = d / row_col_scale # this is the first of two steps of normalizing by these stats. - d_norm1_meansq = (d_norm1 ** 2).mean(dim=(1, 2), keepdim=True) + eps - - d_norm1_scaled = d_norm1 * (d_norm1_meansq * max(rows, cols)) ** (-1/3) - prod3 = compute_prod3(d_norm1_scaled) - + prod3 = compute_scaled_prod3(d_norm1) - alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_scale*(1-beta1)) + alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) # we multiply prod3 by row_col_scale to "un-normalize". # In the normal case where we're not limited by stability-of-update-concerns, # the next line of code is equivalent to: - # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_scale) + # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_proportion) d.add_((prod3 * row_col_scale) * alpha) d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. @@ -350,21 +357,25 @@ def min_sum_scale(x, y): # do "immediate" normalization of total norm to make the overall scale of the update what # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. - assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) # assumed scale of d if stats were i.i.d. + # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style + # accumulator with beta equal to beta1. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) d_norm3 = d_norm2 * (assumed_scale / ((d_norm2 ** 2).mean(dim=(1, 2), keepdim=True) + eps).sqrt()) moving_update = d_norm3 - def s(x): - if x.ndim <= 1: - return x.to('cpu') - else: - return (x ** 2).mean(dim=list(range(1, x.ndim))).sqrt().to('cpu') + #def s(x): + # if x.ndim <= 1: + # return x.to('cpu') + # else: + # return (x ** 2).mean(dim=list(range(1, x.ndim))).sqrt().to('cpu') #if step < 100: # assert torch.all(stored_delta - stored_delta == 0.0), (step, s(stored_delta), s(delta), delta.shape,s(d), s(x3), s(delta2_buffer0), s(delta2_buffer1), s(factor0), s(factor1), s(row_col_scale), s(d2), s(row_col_scale)) + if direct == 0.0: + return -lr * moving_update.reshape(*grad.shape) # row/col normalization of direct/bypass gradient "delta". direct_row_stats.mul_(beta2).add_((delta ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) @@ -483,9 +494,9 @@ def __init__( self, params, lr=1e-03, - beta1=0.998, - direct=0.05, # scale on bypass of momentum (beta1) - cubic_decay_scale=2.5, + beta1=0.999, + direct=0.0, # scale on bypass of momentum (beta1) + cubic_decay_proportion=0.8, beta2=0.98, wd=25, eps=1.0e-16, @@ -496,7 +507,7 @@ def __init__( lr=lr, beta1=beta1, direct=direct, - cubic_decay_scale=cubic_decay_scale, + cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, wd=wd, @@ -893,9 +904,9 @@ def __init__( self, params, lr=1e-03, - beta1=0.998, - direct=0.05, # scale on bypass of momentum (beta1) - cubic_decay_scale=2.5, + beta1=0.999, + direct=0.0, # scale on bypass of momentum (beta1) + cubic_decay_proportion=0.8, beta2=0.98, wd=25, eps=1.0e-16, @@ -905,7 +916,7 @@ def __init__( lr=lr, beta1=beta1, direct=direct, - cubic_decay_scale=cubic_decay_scale, + cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, wd=wd, @@ -1180,6 +1191,20 @@ def lr_lambda(current_step): logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") + +def _test_compute_scaled_prod3(): + x = torch.randn(3, 16, 32) + _U, _S, V = torch.linalg.svd(x, full_matrices=False) + W = V * torch.randn(3, 1, 1) + # so now all the singular values of x will be identical (but arbitrary) + + X = compute_scaled_prod3(W) + #print("X = ", X[0]) + #print("W = ", W[0]) + assert torch.allclose(W, X, atol=1.0e-03) + # but the result won't be identical to the input if the singular values are not all identical. + assert not torch.allclose(x, compute_scaled_prod3(x), atol=1.0e-03) + if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) @@ -1198,4 +1223,5 @@ def lr_lambda(current_step): hidden_dim = 200 #_test_muon(hidden_dim) + _test_compute_scaled_prod3() _test_transformed_adam(hidden_dim) From 4950c4529efaa023da669791898cd1ae3f31a781 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Mar 2026 10:49:29 +0800 Subject: [PATCH 0956/1191] Change how beta2 is set, only affects startup. --- egs/librispeech/ASR/zipformer/optim.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 3743c36304..cf4c8773cd 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -273,8 +273,9 @@ def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] - beta1 = min(group["beta1"], 1. - 1. / (10. + 0.2 * step)) - beta2 = group["beta2"] + beta_ceil = 1. - 1. / (10. + 0.2 * step) + beta1 = min(group["beta1"], beta_ceil) + beta2 = min(group["beta2"], beta_ceil) direct = group["direct"] cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion From 6bd2c55695cd587ae50a798c83174bfa4cd5bacd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Mar 2026 13:41:02 +0800 Subject: [PATCH 0957/1191] Increase direct=0.0 to direct=0.1 --- egs/librispeech/ASR/zapformer/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 0106854086..c083f26013 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1362,6 +1362,7 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, + direct=0.1, wd=25, scale_limits=(1.0, 4.0), ) From fbf23543e81033c29067fd86a2649aba90f11774 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Mar 2026 23:17:23 +0800 Subject: [PATCH 0958/1191] Reduce weight decay from 25 to 18 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index c083f26013..6baa39defe 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1363,7 +1363,7 @@ def run(rank, world_size, args): get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, direct=0.1, - wd=25, + wd=18, scale_limits=(1.0, 4.0), ) From a10dbc6f17ac9595c5c5ea8bcaf364097eb3553d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 10 Mar 2026 23:38:57 +0800 Subject: [PATCH 0959/1191] Reduce cubic_decay_proportion from 0.8 to 0.75. --- egs/librispeech/ASR/zapformer/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6baa39defe..b72236a6dc 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1363,6 +1363,7 @@ def run(rank, world_size, args): get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, direct=0.1, + cubic_decay_proportion=0.75, wd=18, scale_limits=(1.0, 4.0), ) From 4a9ecba3a8ab5cfd0c5be5783cab4b1c1c03e92f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 11 Mar 2026 11:15:54 +0800 Subject: [PATCH 0960/1191] Decrease beta1 from .999 to .998. --- egs/librispeech/ASR/zapformer/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b72236a6dc..866aea9863 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1365,6 +1365,7 @@ def run(rank, world_size, args): direct=0.1, cubic_decay_proportion=0.75, wd=18, + beta1=0.998, scale_limits=(1.0, 4.0), ) From 027132ab0be19af208219fbd9a309395c415239a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Feb 2026 12:11:08 +0800 Subject: [PATCH 0961/1191] Move self-attention weights input to after ff1. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 8b7e7c6012..b874bd710c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -576,12 +576,10 @@ def forward( 2. * aux_loss_scale, mask=src_key_padding_mask), None) - src_pre_ff1 = src src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - # may try changing src_pre_ff1 to src or vice versa. - src = src + self.self_attn(src_pre_ff1, src, attn_mask=attn_mask, + src = src + self.self_attn(src, src, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) From 6f7999e3449011b9a5bf05ab0b9f8c7e4f96eef2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 11 Mar 2026 17:41:06 +0800 Subject: [PATCH 0962/1191] Separate streaming and non-streaming versions of SequenceNorm and remove ballast from non-streaming version. --- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zipformer/scaling.py | 202 +++++++++++++++------ egs/librispeech/ASR/zipformer/zipformer.py | 6 +- 3 files changed, 150 insertions(+), 60 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 06b00827f1..bd6a8a88ee 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -174,7 +174,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,16,8", + default="6,8,14,8", help="Number of zipformer encoder layers per stack, comma separated.", ) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 954d4665ea..8c0580c824 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -333,53 +333,61 @@ def backward(ctx, x_grad, *args): -# all arg tensors are scalars -def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor, causal: bool, mask: Optional[Tensor]): +# all arg tensors are scalars. +def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): + stats = (x ** 2).mean(dim=2, keepdim=True) + T = x.shape[0] # time + if mask is None: + stats = stats.sum(dim=0) + lengths = T + else: + mask = (~mask).to(torch.float).t().unsqueeze(-1) + stats = stats * mask + stats = stats.sum(dim=0) + lengths = mask.sum(dim=0) + + scales = (lengths / stats).sqrt() + assert scales.shape == (x.shape[1], 1) + return x * ((scale * scales) + offset) + +# all arg tensors are scalars. +# mask only used in non-causal mode; ballast_rms and ballast_frames only used in causal mode. +def _causal_sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor): stats = (x ** 2).mean(dim=2, keepdim=True) + + # no need for mask in causal mode. # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so # make absolutely sure using abs(). ballast_frames = 100.0 * ballast_frames.abs() ballast = ballast_frames * (ballast_rms ** 2) T = x.shape[0] # time - if causal: - # no need for mask in causal mode. - stats = stats.cumsum(dim=0) + ballast - lengths = ballast_frames + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] - else: - if mask is None: - # no need for mask in causal mode. - stats = stats.sum(dim=0) + ballast - lengths = ballast_frames + T - else: - mask = (~mask).to(torch.float).t().unsqueeze(-1) - stats = stats * mask - stats = stats.sum(dim=0) + ballast - lengths = ballast_frames + mask.sum(dim=0) + stats = stats.cumsum(dim=0) + ballast + lengths = ballast_frames + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] - scales = (lengths / stats).sqrt() # (T, batch_size, 1) if causal else (batch_size 1) - assert scales.shape == (T, x.shape[1], 1) if causal else (x.shape[1], 1) + scales = (lengths / stats).sqrt() + assert scales.shape == (T, x.shape[1], 1) return x * ((scale * scales) + offset) # all arg tensors are scalars -def _sequence_norm_streaming( - x: Tensor, - offset: Tensor, - scale: Tensor, +def _causal_sequence_norm_streaming( + x: Tensor, + offset: Tensor, + scale: Tensor, cached_stats_sum: Tensor, cached_len: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: """Streaming inference forward for _sequence_norm. We assume that ballast_frames and ballast_rms - are already included in cached_stats_sum and cached_len. + are already included in cached_stats_sum and cached_len. - Args: + Args: x: (seq_len, batch_size, channels) offset: scalar scale: scalar cached_stats_sum: (batch_size,) cached_len: (batch_size,) - + Returns: - normalized x, (seq_len, batch_size, channels) - updated cached_stats_sum, (batch_size,) @@ -391,17 +399,17 @@ def _sequence_norm_streaming( stats = stats.cumsum(dim=0) + cached_stats_sum.unsqueeze(-1) lengths = cached_len[:, None] + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] - + # update cached_stats_sum and cached_len for the next chunk - cached_stats_sum = stats[-1].squeeze(-1) # (batch_size,) + cached_stats_sum = stats[-1].squeeze(-1) # (batch_size,) cached_len = cached_len + T - scales = (lengths / stats).sqrt() # (T, batch_size, 1) - assert scales.shape == (T, x.shape[1], 1) + scales = (lengths / stats).sqrt() # (T, batch_size, 1) + assert scales.shape == (T, x.shape[1], 1) return x * ((scale * scales) + offset), cached_stats_sum, cached_len -class SequenceNormFunction(torch.autograd.Function): +class CausalSequenceNormFunction(torch.autograd.Function): @staticmethod def forward( ctx, @@ -410,14 +418,10 @@ def forward( scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor, - causal: bool, - mask: Optional[Tensor], ) -> Tensor: ctx.save_for_backward(x, offset, scale, ballast_rms, ballast_frames) - ctx.causal = causal - ctx.mask = mask - return _sequence_norm(x, offset, scale, ballast_rms, ballast_frames, causal, mask) + return _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) @staticmethod @@ -433,46 +437,85 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: ballast_frames = ballast_frames.to(torch.float32).detach().requires_grad_() with torch.enable_grad(): - ans = _sequence_norm(x, offset, scale, ballast_rms, ballast_frames, ctx.causal, ctx.mask) + ans = _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) ans.backward(gradient=ans_grad.to(torch.float32)) def c(x): # this is to replace infinities that might be thrown up - # in autocast mode. + # in autocast mode: scalars will tend to have larger grads than non-scalars, + # this code is to reduce the probabilities that any infinities could crash the + # training (it may still happen if the world-size is so large that these + # infinities get added together though). return x.clamp_(min=-30000.0, max=30000.0) - return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad), None, None + return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad) + +class SequenceNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + offset: Tensor, + scale: Tensor, + mask: Optional[Tensor], + ) -> Tensor: + ctx.save_for_backward(x, offset, scale) + ctx.mask = mask + + return _sequence_norm(x, offset, scale, mask) -class SequenceNorm(torch.nn.Module): + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, offset, scale = ctx.saved_tensors + + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float32).detach().requires_grad_() + offset = offset.to(torch.float32).detach().requires_grad_() + scale = scale.to(torch.float32).detach().requires_grad_() + + with torch.enable_grad(): + ans = _sequence_norm(x, offset, scale, ctx.mask) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode: scalars will tend to have larger grads than non-scalars, + # this code is to reduce the probabilities that any infinities could crash the + # training (it may still happen if the world-size is so large that these + # infinities get added together though). + return x if x is None else x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(offset.grad), c(scale.grad), None + + +class CausalSequenceNorm(torch.nn.Module): """ This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence - as well as the channels; and a padding mask is used for irregular length sequences (actually, - the mask is applied multiplicatively as well.) + up to the current point as well as the channels, with some padding of the stats with "default values" + determined by ballast_frames, ballast_rms for robustness near the beginning of the sequence. - There is also a learnable scalar scale and a learnable "offset" value. + There is also a learnable scalar scale, multiplicatively applied to the output, and a learnable + "offset" value that acts multiplicatively on the input without taking into account the rms values. """ def __init__( self, - causal: bool, ) -> None: - super(SequenceNorm, self).__init__() + super().__init__() self.scale = nn.Parameter(torch.tensor(0.5)) self.offset = nn.Parameter(torch.tensor(0.0001)) + # ballast_mean: assumed rms value of ballast frames used to pad stats self.ballast_rms = nn.Parameter(torch.tensor(0.1)) # ballast_frames: number of ballast frames, in hundreds (will be multiplied by 100) self.ballast_frames = nn.Parameter(torch.tensor(0.05)) # number of ballast frames, will be multiplied by 100 - self.causal = causal self.name = None - def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: + def forward(self, x: Tensor, _mask: Optional[Tensor] = None) -> Tensor: # x: (seq, batch, channel) - # mask: bool, (batch_size, seq_len) - # Note: mask is ignored in causal mode. - + # The mask is ignored, it is allowed only for consistency of interface with SequenceNorm. if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _sequence_norm(x, self.offset, self.scale, self.ballast_rms, self.ballast_frames, self.causal, mask) + return _causal_sequence_norm(x, self.offset, self.scale, self.ballast_rms, self.ballast_frames) scale = limit_param_value( self.scale, min=0.05, max=2.0, training=self.training) @@ -486,8 +529,8 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: ballast_frames = limit_param_value( self.ballast_frames, min=0.0, max=5.0, training=self.training) # max of 5.0 would be 500 frames - ans = SequenceNormFunction.apply( - x, offset, scale, ballast_rms, ballast_frames, self.causal, mask, + ans = CausalSequenceNormFunction.apply( + x, offset, scale, ballast_rms, ballast_frames, ) if random.random() < 0.002: @@ -499,7 +542,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: @torch.jit.export def get_init_cache(self, batch_size: int): - """Get initial cache for streaming inference. We first include the ballast stats in the initial cache. + """Get initial cache for streaming inference. We first include the ballast stats in the initial cache. """ # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so # make absolutely sure using abs(). @@ -512,17 +555,60 @@ def get_init_cache(self, batch_size: int): return cached_stats_sum, cached_len def streaming_forward( - self, - x: Tensor, - cached_stats_sum: Tensor, + self, + x: Tensor, + cached_stats_sum: Tensor, cached_len: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: - - x, cached_stats_sum, cached_len = _sequence_norm_streaming( + + x, cached_stats_sum, cached_len = _causal_sequence_norm_streaming( x, self.offset, self.scale, cached_stats_sum, cached_len) return x, cached_stats_sum, cached_len +class SequenceNorm(torch.nn.Module): + """ + This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence + as well as the channels; and a padding mask is used for irregular length sequences (actually, + the mask is applied multiplicatively as well.) + + There is also a learnable scalar scale and a learnable "offset" value. + """ + def __init__( + self, + ) -> None: + super().__init__() + self.scale = nn.Parameter(torch.tensor(0.5)) + self.offset = nn.Parameter(torch.tensor(0.0001)) + self.name = None + + def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: + # x: (seq, batch, channel) + # mask: bool, (batch_size, seq_len) + # Note: mask is ignored in causal mode. + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _sequence_norm(x, self.offset, self.scale, mask) + + scale = limit_param_value( + self.scale, min=0.05, max=2.0, training=self.training) + + offset = limit_param_value( + self.offset, min=0.0, max=10.0, training=self.training) + + ans = SequenceNormFunction.apply( + x, offset, scale, mask, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}") + + return ans + + + # assume layout: (time, batch, channel) def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 3db8fde8df..0bc5466ac2 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -47,6 +47,10 @@ with_loss, ) +try: + from scaling import CausalSequenceNorm +except: + pass from torch import Tensor, nn @@ -554,7 +558,7 @@ def __init__( self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal) - self.norm = SequenceNorm(causal=causal) + self.norm = CausalSequenceNorm() if causal else SequenceNorm() def forward( self, From 1f908bfac2ae4ce92c70b6df096bb8d004ffd915 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 11 Mar 2026 23:08:16 +0800 Subject: [PATCH 0963/1191] Make min_factor simply added linearly (not affect progress) and increase .1->.15. --- egs/librispeech/ASR/zapformer/combined_scheduler.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 54fb799643..c03cfed60b 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -120,18 +120,13 @@ def print_lr(self, is_verbose, group, lr): class CosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, - min_factor: float = 0.1, + min_factor: float = 0.15, **kwargs): super().__init__(*args, **kwargs) - # min_factor has two roles: it acts as a minimum relative learning rate - # (linearly applied, not with max); and it makes the final learning rate - # constant for the final (min_factor) of the schedule, compressing the - # cosine decay into the first (1 - min_factor) of the schedule. self.min_factor = min_factor def get_lr(self): progress = self.get_progress() - progress = min(1.0, progress / (1.0 - self.min_factor)) # clamp progress at 1.0 for final min_factor of schedule. factor = 0.5 * (1.0 + math.cos(math.pi * progress)) factor = self.min_factor + (1. - self.min_factor) * factor # apply min_factor linearly. return [x * factor for x in self.base_lrs] From a217e3d0f49074cb105084acaf47056fa9098fa9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 12 Mar 2026 12:22:36 +0800 Subject: [PATCH 0964/1191] Increase cubic_decay_proportion=0.75 back to cubic_decay_proportion=0.8 and direct=0.1 to direct=0.15 --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 866aea9863..5bd71a5bae 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1362,8 +1362,8 @@ def run(rank, world_size, args): optimizer = TransformedAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.1, - cubic_decay_proportion=0.75, + direct=0.15, + cubic_decay_proportion=0.8, wd=18, beta1=0.998, scale_limits=(1.0, 4.0), From e274669e0d44725ddd8ca1a80b8d390fe6219f91 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 12 Mar 2026 12:31:34 +0800 Subject: [PATCH 0965/1191] Increase cubic_decay_proportion from .8 to .85 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 5bd71a5bae..44aefee59c 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1363,7 +1363,7 @@ def run(rank, world_size, args): get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, direct=0.15, - cubic_decay_proportion=0.8, + cubic_decay_proportion=0.85, wd=18, beta1=0.998, scale_limits=(1.0, 4.0), From af732159ab755e1ca8331603de5c038df7f05366 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 12 Mar 2026 12:51:57 +0800 Subject: [PATCH 0966/1191] Implement min_factor and max_factor in cosine scheduler via changing range of progress, use 0.95,0.05 --- .../ASR/zapformer/combined_scheduler.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index c03cfed60b..d86b8d4a01 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -120,15 +120,24 @@ def print_lr(self, is_verbose, group, lr): class CosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, - min_factor: float = 0.15, + max_factor: float = 0.95, # it will start the cosine schedule from where it's this value, but renormalize so initial factor is 1. + min_factor: float = 0.05, # it will end the cosine schedule at where it's this value **kwargs): super().__init__(*args, **kwargs) - self.min_factor = min_factor + self.max_factor = max_factor + def factor_to_progress(factor): + # inverse function of: factor = 0.5 * (1.0 + math.cos(math.pi * progress)) + cos = 2.0 * factor - 1.0 + return math.acos(cos) / math.pi + self.initial_progress = factor_to_progress(max_factor) + self.final_progress = factor_to_progress(min_factor) def get_lr(self): progress = self.get_progress() + # map progress in [0..1] to a tighter range like [0.15..0.85] + progress = self.initial_progress + (self.final_progress - self.initial_progress) * progress factor = 0.5 * (1.0 + math.cos(math.pi * progress)) - factor = self.min_factor + (1. - self.min_factor) * factor # apply min_factor linearly. + factor = factor / self.max_factor # make it so the initial factor is 1.0 despite limiting range of progress return [x * factor for x in self.base_lrs] From d343e0a09ec0f02fa81cd688c6742c611994ab55 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 13 Mar 2026 11:37:43 +0800 Subject: [PATCH 0967/1191] Revert "Move self-attention weights input to after ff1." This reverts commit 4da937c0f9eef0328f0fca13da836e48a51a5e58. --- egs/librispeech/ASR/zipformer/zipformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b874bd710c..8b7e7c6012 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -576,10 +576,12 @@ def forward( 2. * aux_loss_scale, mask=src_key_padding_mask), None) + src_pre_ff1 = src src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - src = src + self.self_attn(src, src, attn_mask=attn_mask, + # may try changing src_pre_ff1 to src or vice versa. + src = src + self.self_attn(src_pre_ff1, src, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) From f0410f6c02ab053df9d9ab058549e2e99e7eb63d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 13 Mar 2026 12:04:47 +0800 Subject: [PATCH 0968/1191] Documentation changes. --- .../ASR/zapformer/combined_scheduler.py | 47 +++++++++++++++++-- egs/librispeech/ASR/zapformer/train.py | 4 +- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index d86b8d4a01..6a91a312cc 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -8,7 +8,26 @@ class CombinedLRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch; it estimates the "progress" for you. + batch and the epoch; it estimates the "progress" for you based on the epoch you are + in and the estimated progress within the epoch based on + the number of steps within the epoch. The interface is as follows;; suppose + you're using CosineLRScheduler that inherits from this (below). batches_per_epoch + is your best guess at how many batches you will have per epoch; if you get this + wrong there will be a discontinuity in the learning rate as you start the second + epoch. + + num_epochs = 20 + scheduler = CosineLRScheduler(optimizer, batches_per_epoch=2512, num_epochs=num_epochs) + for epoch in range(1, num_epochs + 1): + scheduler.set_epoch(epoch) # caution: one-based epoch count + for batch_idx, batch in enumerate(train_dl): + scheduler.set_batch_idx(batch_idx) + + Args: + optimizer: optimizer that we will set the learning rates in; the initial learning rate(s) in + the optimizer is/are the base LRs and we set the LR as a fraction of those. + batches_per_epoch: the estimated number of batches per epoch; use your best guess. + num_epochs: the total number of epochs you will train for """ def __init__(self, optimizer: Optimizer, @@ -71,15 +90,18 @@ def get_lr(self): raise NotImplementedError def set_batch(self, batch: int): + """ Sets the batch index within the epoch, with zero-based counting (not that this matters much).""" # set the within-epoch batch index. self.batch = batch self._set_lrs() def set_epoch(self, epoch: int): + """ Sets the epoch with one-based counting, so the first epoch is 1; the epoch should not exceed the num_epochs used + in the constructor. """ assert epoch > 0 and epoch <= self.num_epochs # Epoch numbers are assumed to be be 1-based indexes. if epoch == self.epoch + 1 and self.batch > 0: - logging.info(f"Overriding batches_per_epoch from {self.batches_per_epoch} to {self.batch} based on observed batch count.") - self.batches_per_epoch = self.batch + logging.info(f"Overriding batches_per_epoch from {self.batches_per_epoch} to {self.batch+1} based on observed batch count.") + self.batches_per_epoch = self.batch + 1 self.epoch = epoch self._set_lrs() @@ -120,15 +142,30 @@ def print_lr(self, is_verbose, group, lr): class CosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, - max_factor: float = 0.95, # it will start the cosine schedule from where it's this value, but renormalize so initial factor is 1. - min_factor: float = 0.05, # it will end the cosine schedule at where it's this value + max_factor: float = 0.95, # it will start the cosine schedule from the point where it would have this, but renormalize so initial factor is 1; + min_factor: float = 0.05, # it will end the cosine schedule at where it's this value divided by max_factor **kwargs): + """ + Cosine learning rate scheduler that inherits from CombinedLRScheduler (see its documentation + to understand general aspects of usage). + Args: + max_factor, min_factor: The conventional cosine factor goes from 1 to 0 based on the formula: + factor = 0.5 * (1 + cos(pi * progress)). + This scheduler selects the part of that graph from factor=max_factor + to factor=min_factor (imagine cropping the graph by selecting lines + that intersect the y-axis at hose values). It renormalizes so the initial + factor is one by dividing by max_factor; the last factor will actually + be min_factor / max_factor. + """ super().__init__(*args, **kwargs) self.max_factor = max_factor def factor_to_progress(factor): # inverse function of: factor = 0.5 * (1.0 + math.cos(math.pi * progress)) cos = 2.0 * factor - 1.0 return math.acos(cos) / math.pi + + # we'll divide the factors by max_factor in get_lr() after computing the cosine formula, + # so the initial and final factors will actually be 1.0 and min_factor respectively. self.initial_progress = factor_to_progress(max_factor) self.final_progress = factor_to_progress(min_factor) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 44aefee59c..e2e438915e 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1403,7 +1403,7 @@ def lr_lambda(current_step): asr_datamodule = AsrDataModule(args) librispeech = LibriSpeech(args.manifest_dir) - gigaspeech = GigaSpeech(args.manifest_dir) # gigaspeech will only be used if --libri-copies set. this is not a typo! + gigaspeech = GigaSpeech(args.manifest_dir) # gigaspeech will only be used if the --use-giga=True option is set if params.full_libri: train_cuts = librispeech.train_all_shuf_cuts() @@ -1419,7 +1419,7 @@ def lr_lambda(current_step): # train_cuts += librispeech.train_other_500_cuts() else: train_cuts = librispeech.train_clean_100_cuts() - train_cuts_len = 100.0 * 3 # 100 hours times 3 for augmentation + train_cuts_len = 100.0 * 3 # 100 hours times 3 for speed augmentation if params.use_giga: if params.full_libri: From 1b322477f5a63fc16badf591e8d973deae5c1390 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 13 Mar 2026 15:27:42 +0800 Subject: [PATCH 0969/1191] Large amount of code cleanup and removal. --- egs/librispeech/ASR/zapformer/model.py | 2 +- egs/librispeech/ASR/zapformer/test_scaling.py | 1 - egs/librispeech/ASR/zapformer/train.py | 1 - egs/librispeech/ASR/zipformer/finetune.py | 4 +- egs/librispeech/ASR/zipformer/optim.py | 13 +- egs/librispeech/ASR/zipformer/scaling.py | 1509 +---------------- egs/librispeech/ASR/zipformer/subsampling.py | 44 +- .../ASR/zipformer/test_subsampling.py | 4 +- egs/librispeech/ASR/zipformer/zipformer.py | 27 +- 9 files changed, 78 insertions(+), 1527 deletions(-) delete mode 120000 egs/librispeech/ASR/zapformer/test_scaling.py diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index cc96669bc4..56e5f896e1 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from torch import Tensor from encoder_interface import EncoderInterface -from scaling import ScaledLinear, convert_num_channels, PredictLoss +from scaling import ScaledLinear, convert_num_channels from icefall.utils import add_sos, make_pad_mask, time_warp diff --git a/egs/librispeech/ASR/zapformer/test_scaling.py b/egs/librispeech/ASR/zapformer/test_scaling.py deleted file mode 120000 index b776da79a1..0000000000 --- a/egs/librispeech/ASR/zapformer/test_scaling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6808cf8a87..17fcba7d26 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -82,7 +82,6 @@ except: pass from torch.optim.lr_scheduler import LambdaLR -from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor from torch.cuda.amp import GradScaler diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 2ff6319140..00f74e3f6c 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -75,7 +75,6 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam -from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor from torch.cuda.amp import GradScaler @@ -608,7 +607,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: encoder_embed = Conv2dSubsampling( in_channels=params.feature_dim, out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + causal=params.causal, ) return encoder_embed @@ -627,7 +626,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=_to_int_tuple(params.num_heads), feedforward_dim=_to_int_tuple(params.feedforward_dim), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, causal=params.causal, chunk_size=_to_int_tuple(params.chunk_size), diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index cf4c8773cd..47d2732255 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -996,8 +996,6 @@ def _test_transformed_adam(hidden_dim: int): m = torch.nn.Sequential( Linear(E, hidden_dim), - OrthogonalLinear(hidden_dim, hidden_dim, bias=True, - in_groups=2, group_size=hidden_dim//4), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), torch.nn.PReLU(), @@ -1052,17 +1050,16 @@ def lr_lambda(current_step): avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - norm2 = '%.2e' % (m[1].weight**2).mean().sqrt().item() - norm3 = '%.2e' % (m[3].weight**2).mean().sqrt().item() - norm4 = '%.2e' % (m[5].weight**2).mean().sqrt().item() + norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + norm3 = '%.2e' % (m[4].weight**2).mean().sqrt().item() bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() - bias_norm2 = '%.2e' % (m[3].bias**2).mean().sqrt().item() - bias_norm3 = '%.2e' % (m[5].bias**2).mean().sqrt().item() + bias_norm2 = '%.2e' % (m[2].bias**2).mean().sqrt().item() + bias_norm3 = '%.2e' % (m[4].bias**2).mean().sqrt().item() lr = scheduler.get_last_lr()[0] logging.info( - f"Test {test}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3,norm4}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + f"Test {test}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" ) loss.log().backward() optim.step() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 8c0580c824..06cb538627 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -28,6 +28,24 @@ from torch.cuda.amp import custom_bwd, custom_fwd + + +class FloatLike: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class ScheduledFloat: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class SimpleOrthogonalLinear: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class PiecewiseLinear: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class CosineSimilarityLoss: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class PredictLoss: # TODO: remove. this is to solve problems with multiple jobs running. + pass +get_max_similarity = None # TODO: remove. this is to solve problems with multiple jobs running. + + + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) diff = torch.abs(x - y) @@ -59,201 +77,6 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: return torch.logaddexp(x, y) -class PiecewiseLinear(object): - """ - Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with - the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] - respectively. - """ - - def __init__(self, *args): - assert len(args) >= 1, len(args) - if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - self.pairs = list(args[0].pairs) - else: - self.pairs = [(float(x), float(y)) for x, y in args] - for x, y in self.pairs: - assert isinstance(x, (float, int)), type(x) - assert isinstance(y, (float, int)), type(y) - - for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], ( - i, - self.pairs[i], - self.pairs[i + 1], - ) - - def __str__(self): - # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f"PiecewiseLinear({str(self.pairs)[1:-1]})" - - def __call__(self, x): - if x <= self.pairs[0][0]: - return self.pairs[0][1] - elif x >= self.pairs[-1][0]: - return self.pairs[-1][1] - else: - cur_x, cur_y = self.pairs[0] - for i in range(1, len(self.pairs)): - next_x, next_y = self.pairs[i] - if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) - cur_x, cur_y = next_x, next_y - assert False - - def __mul__(self, alpha): - return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) - - def __add__(self, x): - if isinstance(x, (float, int)): - return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) - s, x = self.get_common_basis(x) - return PiecewiseLinear( - *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def max(self, x): - if isinstance(x, (float, int)): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def min(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def __eq__(self, other): - return self.pairs == other.pairs - - def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): - """ - Returns (self_mod, p_mod) which are equivalent piecewise linear - functions to self and p, but with the same x values. - - p: the other piecewise linear function - include_crossings: if true, include in the x values positions - where the functions indicate by this and p cross. - """ - assert isinstance(p, PiecewiseLinear), type(p) - - # get sorted x-values without repetition. - x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - - if include_crossings: - extra_x_vals = [] - for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): - # if the two lines in this subsegment potentially cross each other.. - diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) - # `pos`, between 0 and 1, gives the relative x position, - # with 0 being x_vals[i] and 1 being x_vals[i+1]. - pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) - extra_x_vals.append(extra_x_val) - if len(extra_x_vals) > 0: - x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - return ( - PiecewiseLinear(*zip(x_vals, y_vals1)), - PiecewiseLinear(*zip(x_vals, y_vals2)), - ) - - -class ScheduledFloat(torch.nn.Module): - """ - This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); - it does not have a working forward() function. You are supposed to cast it to float, as - in, float(parent_module.whatever), and use it as something like a dropout prob. - - It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specify the (x,y) pairs - in sorted order on x; x corresponds to the batch index. For batch-index values before the - first x or after the last x, we just use the first or last y value. - - Example: - self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) - - `default` is used when self.batch_count is not set or not in training mode or in - torch.jit scripting mode. - """ - - def __init__(self, *args, default: float = 0.0): - super().__init__() - # self.batch_count and self.name will be written to in the training loop. - self.batch_count = None - self.name = None - self.default = default - self.schedule = PiecewiseLinear(*args) - - def extra_repr(self) -> str: - return ( - f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" - ) - - def __float__(self): - batch_count = self.batch_count - if ( - batch_count is None - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return float(self.default) - else: - ans = self.schedule(self.batch_count) - if random.random() < 0.0002: - logging.info( - f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" - ) - return ans - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, default=self.default) - else: - return ScheduledFloat( - self.schedule + x.schedule, default=self.default + x.default - ) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), default=self.default) - else: - return ScheduledFloat( - self.schedule.max(x.schedule), default=max(self.default, x.default) - ) - - -FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - - - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should @@ -291,47 +114,6 @@ def softmax(x: Tensor, dim: int): return SoftmaxFunction.apply(x, dim) -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - # all arg tensors are scalars. def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): @@ -616,9 +398,6 @@ def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): return x * scales -class GaussNorm: - # this is to prevent errors when running multiple jobs. - pass class RmsNormFunction(torch.autograd.Function): @staticmethod @@ -764,506 +543,7 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: return ans -def predict_loss(x: Tensor, predictor: nn.Module, proj_weight: Tensor, - name: str, - mask: Optional[Tensor]) -> Tensor: - # caution: require input to be (seq, batch, channel) - batch_size = x.shape[1] - - if batch_size % 2 != 0: - assert (not x.requires_grad), "PredictLoss must be used with CR-CTC or similar thing that repeats batch with different augmentation." - return torch.tensor(0.0, device=x.device) - - def gauss_norm(x): - # normalize by gaussianizing on each dimension - values, indexes = x.sort(dim=0) # sort on seq dim - # norm_rank: same shape as x - N = max(2, x.shape[0]) - norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[0], device=x.device, dtype=torch.float) - norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data - norm_rank = norm_rank.reshape(-1, 1, 1) - norm_rank = norm_rank.repeat(1, x.shape[1], x.shape[2]) - x_norm = torch.empty_like(x) - x_norm.scatter_(dim=0, index=indexes, src=norm_rank) - return x_norm - - with torch.no_grad(): - # get the indexes. project, then mean-and-variance-norm, then - # take mx. - x_proj = torch.matmul(x, proj_weight.t()) - with torch.amp.autocast('cuda', enabled=False): - x_proj = gauss_norm(x_proj.to(torch.float)) - - - x_proj = torch.roll(x_proj, batch_size // 2, 1) - x_pred = predictor(x) - - loss = ((x_pred - x_proj) ** 2).mean(dim=-1) - - if random.random() < 0.002: - logging.info(f"predict_loss: name={name}, mean loss before scale = {loss.mean()}") - - if mask is not None: - mask = mask.to(x.dtype) - # note, this mask is True for *non*-masked positions. - # we swap the mask over the two copies of the data; the mask goes with the thing that - # is predicted, not the thing we predict it from.. the idea being that we don't want to ask - # the model to predict masked portions of the time sequence. - mask = torch.roll(mask, batch_size // 2, 1) - loss = loss * mask - - return loss.sum() # we reduce with sum in what we return. - - -class PredictorConvModule(nn.Module): - """A convolution module with a residual connecction, modified from ConvolutionModule in Zipformer2, that is used as - the predictor network in class Predictor. The input format is (seq, batch, channels). - - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - kernel_size: int, - out_channels: int, - ) -> None: - """Construct a ConvolutionModule object.""" - super().__init__() - assert (kernel_size - 1) % 2 == 0 - - self.in_proj = nn.Linear( - channels, - hidden_channels, - ) - - self.bypass_proj = nn.Linear( - channels, - out_channels, - ) - - self.depthwise_conv = nn.Conv1d( - in_channels=hidden_channels, - out_channels=hidden_channels, - groups=hidden_channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - - self.out_proj = ActivationDropoutAndLinear( - hidden_channels, - out_channels, - activation="SwashR", - dropout_p=0.0, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - ) -> Tensor: - bypass = self.bypass_proj(x) - x = self.in_proj(x) # (time, batch, 2*channels) - x = x.permute(1, 2, 0) # (#batch, channels, time). - x = self.depthwise_conv(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - x = bypass + self.out_proj(x) # includes activation. - return x - - - -class PredictLoss(nn.Module): - """ - Adds an auxiliary loss based on predicting the top-1 of randomized codebook - entries. (This relies on the CR-CTC structure of having two differently-masked - copies of the same utterance). Mean and variance normalization is applied prior to getting - the codebook indexes to keep this stable. - """ - def __init__(self, - num_channels: int, - codebook_size: int = 64): - super().__init__() - scale = num_channels ** -0.5 - self.register_buffer('proj_weight', - scale * torch.randn(codebook_size, num_channels), - persistent=True) - num_hidden = max(1024, num_channels) - kernel_size = 7 - self.predictor = PredictorConvModule(num_channels, num_hidden, kernel_size, codebook_size) - - self.name = None # will be set from training code - - def forward(self, - x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - # x is of shape (seq_len, batch_size, num_channels); mask is of shape - # (seq_len, batch_size), with True for *non*-masked positions. - return predict_loss(x, self.predictor, self.proj_weight, - self.name, mask) - - - -class OrthogonalLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x: Tensor, weight: Tensor, name: str, in_groups: int, - out_groups: int, group_size: int, penalty_scale: float): - ctx.save_for_backward(x, weight) - ctx.name = name - ctx.out_groups = out_groups - ctx.in_groups = in_groups - ctx.group_size = group_size - ctx.penalty_scale = penalty_scale - assert not (in_groups > 0 and out_groups > 0) - return torch.matmul(x, weight.t()) - - @staticmethod - @custom_bwd - def backward(ctx, y_grad): - x, weight = ctx.saved_tensors - - if x.requires_grad: - x_grad = torch.matmul(y_grad, weight) - else: - x_grad = None - - out_groups, in_groups, group_size = ctx.out_groups, ctx.in_groups, ctx.group_size - - if weight.requires_grad: - weight_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(), - x.reshape(-1, x.shape[-1])) - else: - weight_grad = None - - if weight.requires_grad and ctx.penalty_scale != 0.0: - penalty_scale = ctx.penalty_scale * weight_grad.abs().mean() - - with torch.enable_grad(): - weight = weight.detach() - weight.requires_grad = True - - # Get extra gradient term that penalizes non-orthogonality. - - # First get w which is of shape (num_groups, out_channels_per_group, in_channels_per_group) - if out_groups > 0: - w = weight[:out_groups*group_size].reshape(out_groups, group_size, weight.shape[1]) - elif in_groups > 0: - w = weight[:, :in_groups*group_size].reshape(weight.shape[0], in_groups, group_size).transpose(0, 1) - else: - w = weight.unsqueeze(0) - - - # Compute symmetric matrix-product prod with the smallest - # dimension possible given the shape of w. This is not just for - # efficiency; if we computed it the wrong way round, the product - # would have deficient rank and could never be the identity. - if (w.shape[1] > w.shape[2]): - prod = torch.matmul(w.transpose(1, 2), w) - else: - prod = torch.matmul(w, w.transpose(1, 2)) - - # we'll try to enforce that for any i, prod[i] is any constant times the identity. - - # in the loss-function: - # orthogonality_loss = ((prod * alpha - I) ** 2).sum(), - # the following formula gives the alpha that means d(err)/d(scale-of-prod) will be zero. - # alpha = prod.diag().mean() / (prod ** 2).sum(dim=1).mean(dim=0) - - # note, prod_diag shares memory with prod, this will matter later on. - (groups, r, c) = prod.shape - (groups_stride, r_stride, c_stride) = prod.stride() - - def diag_inplace(z): - return torch.as_strided(z, size=(groups, r), stride=(groups_stride, r_stride+c_stride)) - - with torch.no_grad(): - # alpha: (groups, 1) - alpha = (diag_inplace(prod).mean(dim=1, keepdim=True) / - (prod ** 2).sum(dim=2).mean(dim=1, keepdim=True)) - - prod *= alpha.unsqueeze(-1) - diag_inplace(prod)[:] -= 1. - - # that loss that we want to backprop would be 0.5 * (prod ** - # 2).sum() * penalty_scale. we can backprop this without doing - # any reductions as follows: - prod.backward(gradient=prod * penalty_scale) - - - do_print = random.random() < 0.002 - if do_print: - # we print a normalized version of the loss, by dividing by the - # number of rows. - loss = (prod ** 2).mean(dim=(1,2)) * prod.shape[1] - logging.info(f"OrthogonalLinear: name={ctx.name}, scale={(1. / alpha).sqrt().cpu().flatten()}, loss={loss.detach().cpu().flatten()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") - - - # add the extra gradient term from the orthogonality loss. - weight_grad += weight.grad - return x_grad, weight_grad, None, None, None, None, None - - - -class OrthogonalLinear(nn.Linear): - """ - Like nn.Linear but can enforce that the weight matrix, or selected parts of it, is - orthogonal up to a scalar factor. We are using a generalized definition of "orthogonal" - that applies to non-square matrix, i.e. that either M^T M or M M^T, whichever has - fewer rows/columns, should be equal to the identity times some positive scalar alpha. - (If M is square, these definitions are equivalent and is equivalent to the normal - definition of orthogonal). - - Args: - in_channels: number of input channels - out_channels: number of output channels - in_groups: the number of groups on the input dimension, if specified - the orthogonality-up-to-a-scalar-factor constraint will be - applied separately per group, with different scalars. - out_groups: the number of groups on the output dimension; you cannot - specify both this and in_groups with values >0. - group_size: the number of channels per group. This provides a way - to ensure that only part of the matrix is subject to the - orthogonality constraint, e.g. if you specified in_groups>0, - you can specify group_size - such that in_groups * group_size < in_channels, and the - remaining channels will be unconstrained. - bias: if True, include a bias term. - initial_scale: a factor that allows you to increase or decrease the - initial scale of the weight (and bias, if present) - penalty_scale: a scale on the penalty on non-orthogonality (this will - be multiplied by the average-absolute-value of the - backpropagated gradient). - """ - # if in_groups or out_groups are set to >1, the orthogonal constraint - # will be set per group. both of them cannot be >1. - def __init__(self, - in_channels: int, - out_channels: int, - in_groups: int = -1, - out_groups: int = -1, - group_size: int = -1, - bias: bool = True, - initial_scale: float = 1.0, - penalty_scale: FloatLike = 20.0, - ): - super().__init__(in_channels, out_channels, bias=bias) - self.name = None - self.in_groups = in_groups - self.out_groups = out_groups - if in_groups > 0 and group_size == -1: - group_size = in_channels // in_groups - elif out_groups > 0 and group_size == -1: - group_size = out_channels // out_groups - self.group_size = group_size - self.penalty_scale = copy.deepcopy(penalty_scale) - - # the same scaling as for ScaledLinear. - with torch.no_grad(): - self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * initial_scale - if self.bias is not None: - torch.nn.init.uniform_(self.bias, -0.01 * initial_scale, 0.01 * initial_scale) - - - def forward(self, x: Tensor): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return torch.nn.functional.linear(x, self.weight, self.bias) - - ans = OrthogonalLinearFunction.apply(x, self.weight, self.name, - self.in_groups, self.out_groups, - self.group_size, float(self.penalty_scale)) - if self.bias is not None: - ans = ans + self.bias - return ans - - -class MaxVarLossFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x: Tensor, mask: Optional[Tensor], max_var: float, weight: float, name: str): - ctx.save_for_backward(x) - if mask is not None: - assert mask.shape == x.shape[:2], (list(mask.shape), list(x.shape)) - ctx.mask = mask # mask will have no grad so it should be OK to store this way - ctx.name = name - ctx.weight = weight - ctx.max_var = max_var - return torch.tensor(0.0, device=x.device, dtype=x.dtype) - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - x, = ctx.saved_tensors - mask = ctx.mask # optional Tensor - name = ctx.name # str - weight = ctx.weight # float - max_var = ctx.max_var # float - - - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - - eps = 3.0e-08 # won't be zero in float16 - x_var = (x ** 2).mean(dim=-1) - if mask is not None: - mask = (~mask).to(x.dtype) - x_var = x_var * mask - - with torch.amp.autocast('cuda', enabled=False): - x_var = x_var.to(torch.float) - if mask is not None: - numel = mask.sum() - else: - numel = x_var.numel() - excess_var = (x_var.sum() - max_var * numel).relu() - - if random.random() < 0.001: - logging.info(f"MaxVarLoss: {name}, limit={max_var}, excess-var={excess_var.mean() / numel}") - - # scale the loss by less than one, if we are close to the limit. - excess_var = excess_var * (excess_var / (numel * max_var)).clamp(max=1.0) - - # also add a factor of 1. / max_var into the loss scale. - excess_var.backward(gradient=torch.full_like(excess_var, weight * (1. / max_var))) - - return x.grad, None, None, None, None - - -class MaxVarLoss(nn.Module): - def __init__(self, - max_rms: FloatLike): - super().__init__() - self.max_rms = max_rms - self.name = None - - def forward(self, - x: Tensor, - loss_scale: float, - mask: Optional[Tensor] = None) -> Tensor: - """ - Compute loss that acts like a penalty if the mean-square value of x - exceeds self.max_rms**2 - - x: Tensor of shape (batch_size, seq_len, num_channels) - loss_scale: the scale with which the loss should be incorporated into the graph. - This should contain a factor of the grad_scale, if you are using GradScaler for - automatic mixed precision training (amp). - The loss will be summed over frames, and multiplied by this value. - mask: if supplied, mask of shape (batch_size, seq_len); - True means masked positions. - - Returns: - returns a scaled scalar loss value "ret" which should be incorporated - into the backprop graph by doing: - z = with_loss(z, ret, None) - where z is any quantity that will be used in calculating the main loss. - Ret will always be numerically equal to zero in the forward pass but - may behave as if it were nonzero for backprop purposes. - """ - return MaxVarLossFunction.apply(x, mask, - float(self.max_rms) ** 2, - loss_scale, self.name) - - -class CosineSimilarityLossFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x: Tensor, mask: Optional[Tensor], max_similarity: float, weight: float, name: str): - ctx.save_for_backward(x) - if mask is not None: - assert mask.shape == x.shape[:2], (list(mask.shape), list(x.shape)) - ctx.mask = mask # mask will have no grad so it should be OK to store this way - ctx.name = name - ctx.weight = weight - ctx.max_similarity = max_similarity - return torch.tensor(0.0, device=x.device, dtype=x.dtype) - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - x, = ctx.saved_tensors - mask = ctx.mask # optional Tensor - name = ctx.name # str - weight = ctx.weight # float - max_similarity = ctx.max_similarity # float - - - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - - eps = 3.0e-08 # won't be zero in float16 - x_norm = x / ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() - (batch_size, seq_len, num_channels) = x.shape - _, permutation = torch.rand(batch_size, seq_len, device=x.device).sort(dim=1) - # permutation: (batch_size, seq_len) - arange = torch.arange(seq_len, device=x.device) - mask2 = (permutation == arange) - if mask is not None: - mask = torch.logical_or(mask, mask2) - else: - mask = mask2 - x_norm = x_norm * (~mask).unsqueeze(-1).to(x.dtype) - - x_permuted = torch.gather(x_norm, 1, permutation.unsqueeze(-1).expand(*x.shape)) - - similarity = (x_norm * x_permuted).sum(dim=-1).abs() # use absolute value so we penalize negative correlations also - excess_similarity = (similarity.sum(dim=1) - seq_len * max_similarity).relu() - - if random.random() < 0.001: - logging.info(f"CosineSimilarityLoss: {name}, limit={max_similarity}, excess-similarity={excess_similarity.mean() / seq_len}") - - grad = (weight * ans_grad).expand(excess_similarity.numel()) - excess_similarity.backward(grad) - - return x.grad, None, None, None, None - - -class CosineSimilarityLoss(nn.Module): - def __init__(self, - max_similarity: FloatLike): # e.g. 0.1 for max_similarity - super().__init__() - self.max_similarity = max_similarity - self.name = None - - def forward(self, - x: Tensor, - loss_scale: float, - mask: Optional[Tensor] = None) -> Tensor: - """ - Compute cosine-similarity loss that tries to make sure distinct output vectors - have inner products with small magnitude (after normalization), i.e. the cosine - of the angle between should be close to zero. - - x: Tensor of shape (batch_size, seq_len, num_channels) - loss_scale: the scale with which the loss should be incorporated into the graph. - This should contain a factor of the grad_scale, if you are using GradScaler for - automatic mixed precision training (amp). - The loss will be summed over frames, and multiplied by this value. - mask: if supplied, mask of shape (batch_size, seq_len); - True means masked positions. - - Returns: - returns a scaled scalar loss value "ret" which should be incorporated - into the backprop graph by doing: - z = with_loss(z, ret, None) - where z is any quantity that will be used in calculating the main loss. - Ret will always be numerically equal to zero in the forward pass but - may behave as if it were nonzero for backprop purposes. - """ - return CosineSimilarityLossFunction.apply(x, mask, - float(self.max_similarity), - loss_scale, self.name) - - - -class SimpleOrthogonalPenaltyFunction(torch.autograd.Function): +class OrthogonalPenaltyFunction(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, weight: Tensor, penalty_scale: float, name: str): @@ -1325,7 +605,7 @@ def diag_inplace(z): weight_grad = weight_grad + weight.grad return weight_grad, None, None -class SimpleOrthogonalLinear(nn.Linear): +class OrthogonalLinear(nn.Linear): """ Like nn.Linear but can enforce that the weight matrix is orthogonal; in the non-square case this is interpreted as either M^T M == I or M M^T == I, whichever would give a smaller @@ -1352,7 +632,7 @@ def __init__(self, out_channels: int, lr_scale: float = 1.0, bias: bool = True, - penalty_scale: FloatLike = 20.0, + penalty_scale: float = 20.0, ): super().__init__(in_channels, out_channels, bias=bias) self.name = None @@ -1372,424 +652,12 @@ def forward(self, x: Tensor, transpose: bool = False): if lr_scale != 1.0: weight = weight * lr_scale if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): - weight = SimpleOrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) + weight = OrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) if transpose: weight = weight.t() return torch.nn.functional.linear(x, weight, self.bias) -def get_max_similarity(rank: int, power: float): - """ - For use when initializing CosineSimilarityLoss, this returns a value for - the "max_similarity" argument. - max_similarity is an upper limit we impose on the mean value of (x_i . x_j), - where i != j are two different sequence-position indexes and x_i and x_j are - activation vectors normalized to have unit length. - - rank: the dimension of the space, usually this is the num_channels, but if - we have just up-projected from a bottleneck, it would be the bottleneck - dimension. - power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean - we enforce the vector dimensions to be completely independent like Gaussian noise - (don't do this); if we set power=0.0 it would be equivalent to not having - the CosineSimilarityLoss at all. - - The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal - variable. If x consists of independent Gaussian noise of dimension D, with - variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" - would be close to a no-op for large D), then (x_i . x_j) would be distributed as - a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) - would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value - between this and 1, as a kind of heuristic limit on this max_similarity. - """ - return (0.7978845608 / (rank ** 0.5)) ** power - - -class MinProductLossFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x: Tensor, y: Tensor, mask: Optional[Tensor], min_product: float, weight: float, name: str): - ctx.save_for_backward(x, y) - ctx.mask = mask # mask will have no grad so it should be OK to store this way - ctx.name = name - ctx.weight = weight - ctx.min_product = min_product - return torch.tensor(0.0, device=x.device, dtype=x.dtype) - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - x, y = ctx.saved_tensors - mask = ctx.mask # optional Tensor - name = ctx.name # str - weight = ctx.weight # float - min_product = ctx.min_product # float - - - with torch.enable_grad(): - x, y = x.detach(), y.detach() - x.requires_grad = True - y.requires_grad = True - - eps = 3.0e-08 # won't be zero in float16 - x_norm = x / ((x ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() - y_norm = y / ((y ** 2).sum(dim=-1, keepdim=True) + eps).sqrt() - (batch_size, seq_len, num_channels) = x.shape - - - - product = x_norm * y_norm - product = product.sum(dim=-1) - if mask is not None: - inv_mask = (~mask).to(x.dtype) - product = product * inv_mask - - if mask is not None: - product_deficit = (inv_mask.sum(dim=1) * min_product - product.sum(dim=1)).relu() - else: - product_deficit = (seq_len * min_product - product.sum(dim=1)).relu() - - if random.random() < 0.005: - logging.info(f"MinProductLoss: {name}, limit={min_product}, product-deficit={product_deficit.mean() / seq_len}") - - grad = (weight * ans_grad).expand(product_deficit.numel()) - product_deficit.backward(grad) - - return x.grad, y.grad, None, None, None, None - -class MinProductLoss(nn.Module): - def __init__(self, - min_product: FloatLike): # e.g. 0.5 for min_product - super().__init__() - self.min_product = min_product - self.name = None - - def forward(self, - x: Tensor, - y: Tensor, - loss_scale: float, - mask: Optional[Tensor] = None) -> Tensor: - """ - Compute loss that tries to keep two embeddings in similar directions, used to - make sure that the bulk of the embedding goes through one branch. - - x: Tensor of shape (batch_size, seq_len, num_channels) - y: Tensor of shape (batch_size, seq_len, num_channels) - loss_scale: the scale with which the loss should be incorporated into the graph. - This should contain a factor of the grad_scale, if you are using GradScaler for - automatic mixed precision training (amp). - The loss will be summed over frames, and multiplied by this value. - mask: if supplied, mask of shape (batch_size, seq_len); - True means masked positions that will be ignored. - - Returns: - returns a scaled scalar loss value "ret" which should be incorporated - into the backprop graph by doing: - z = with_loss(z, ret, None) - where z is any quantity that will be used in calculating the main loss. - Ret will always be numerically equal to zero in the forward pass but - may behave as if it were nonzero for backprop purposes. - """ - return MinProductLossFunction.apply(x, y, mask, - float(self.min_product), - loss_scale, self.name) - - -class MinProductLoss(nn.Module): - def __init__(self, - min_product: FloatLike): # e.g. 0.5 for min_product - super().__init__() - self.min_product = min_product - self.name = None - - def forward(self, - x: Tensor, - y: Tensor, - loss_scale: float, - mask: Optional[Tensor] = None) -> Tensor: - """ - Compute loss that tries to keep two embeddings in similar directions, used to - make sure that the bulk of the embedding goes through one branch. - - x: Tensor of shape (batch_size, seq_len, num_channels) - y: Tensor of shape (batch_size, seq_len, num_channels) - loss_scale: the scale with which the loss should be incorporated into the graph. - This should contain a factor of the grad_scale, if you are using GradScaler for - automatic mixed precision training (amp). - The loss will be summed over frames, and multiplied by this value. - mask: if supplied, mask of shape (batch_size, seq_len); - True means masked positions that will be ignored. - - Returns: - returns a scaled scalar loss value "ret" which should be incorporated - into the backprop graph by doing: - z = with_loss(z, ret, None) - where z is any quantity that will be used in calculating the main loss. - Ret will always be numerically equal to zero in the forward pass but - will behave as if it were nonzero for backprop purposes. - """ - return MinProductLossFunction.apply(x, y, mask, - float(self.min_product), - loss_scale, self.name) - - -# cross cosine loss is for when you have a situation like: -# y = y + delta -# y = with_loss(y, cross_cosine_loss(x, y, delta)) -# and we want to make sure that adding delta does not change the magnitude -# of individual embedding vectors very much. -# we do this by making sure that mean(abs(log(|x_i|) - log(|y_i|))) <= limit. -class NormChangeLossFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x: Tensor, y: Tensor, mask: Optional[Tensor], - limit: float, weight: float, name: str): - ctx.save_for_backward(x, y) - ctx.name = name - ctx.mask = mask # mask will have no grad so it should be OK to store this way - ctx.weight = weight - ctx.limit = limit - # return fake loss that is always zero but behaves in backprop as if it were a real loss. - return torch.tensor(0.0, device=x.device, dtype=x.dtype) - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - x, y = ctx.saved_tensors - name = ctx.name # str - mask = ctx.mask # Tensor or None, shape: (batch_size, seq_len) - weight = ctx.weight # float - limit = ctx.limit # float - (batch_size, seq_len, num_channels) = x.shape - - with torch.enable_grad(): - with torch.amp.autocast('cuda', enabled=False): - x, y = x.to(torch.float), y.to(torch.float) - x, y = x.detach(), y.detach() - x.requires_grad = True - y.requires_grad = True - eps = 1.0e-10 - x_sqnorm = (x * x).sum(dim=-1) + eps - y_sqnorm = (y * y).sum(dim=-1) + eps - norm_diff = 0.5 * (x_sqnorm.log() - y_sqnorm.log()).abs() - - if mask is not None: - norm_diff = norm_diff * (~mask).to(norm_diff.dtype) - - excess_norm_diff = (norm_diff.sum(dim=1) - seq_len * limit).relu() - - if random.random() < 0.001: - logging.info(f"NormChangeLoss: {name}, limit={limit}, excess-norm-diff={excess_norm_diff.mean() / seq_len}") - - grad = (weight * ans_grad).expand(excess_norm_diff.numel()) - excess_norm_diff.backward(grad) - - return x.grad, y.grad, None, None, None, None - -class NormChangeLoss(nn.Module): - def __init__(self, - limit: FloatLike): # e.g. 0.2. - super().__init__() - self.limit = limit - self.name = None - - def forward(self, - x: Tensor, - y: Tensor, - loss_scale: float, - mask: Optional[Tensor]) -> Tensor: - """ - Compute loss that limits the average value over the sequence of abs((delta . x) / (x . x)) - - - x: Tensor of shape (batch_size, seq_len, num_channels) - y: Tensor of shape (batch_size, seq_len, num_channels) - loss_scale: the scale with which the loss should be incorporated into the graph. - This should contain a factor of the grad_scale, if you are using GradScaler for - automatic mixed precision training (amp). - The loss will be summed over frames of x, i.e. scaled like - batch_size * seq_len * loss_scale * [average excess product] - - Returns: - returns a scaled scalar loss value "ret" which should be incorporated - into the backprop graph by doing: - z = with_loss(z, ret, None) - where z is any quantity that will be used in calculating the main loss. - Ret will always be numerically equal to zero in the forward pass but - will behave as if it were nonzero for backprop purposes. - """ - limit = float(self.limit) - return NormChangeLossFunction.apply(x, y, mask, limit, - loss_scale, self.name) - - -class ChunkCausalDepthwiseConv1d(torch.nn.Module): - """ - Behaves like a depthwise 1d convolution, except that it is causal in - a chunkwise way, as if we had a block-triangular attention mask. - The chunk size is provided at test time (it should probably be - kept in sync with the attention mask). - - This has a little more than twice the parameters of a conventional - depthwise conv1d module: we implement it by having one - depthwise convolution, of half the width, that is causal (via - right-padding); and one depthwise convolution that is applied only - within chunks, that we multiply by a scaling factor which depends - on the position within the chunk. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - - def __init__( - self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True, - ): - super().__init__() - assert kernel_size % 2 == 1 - - half_kernel_size = (kernel_size + 1) // 2 - # will pad manually, on one side. - self.causal_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True, - ) - - self.chunkwise_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias, - ) - - # first row is correction factors added to the scale near the left edge of the chunk, - # second row is correction factors added to the scale near the right edge of the chunk, - # both of these are added to a default scale of 1.0. - self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) - self.kernel_size = kernel_size - self.left_pad = half_kernel_size - 1 - - with torch.no_grad(): - self.causal_conv.weight[:] *= initial_scale - self.chunkwise_conv.weight[:] *= initial_scale - if bias: - torch.nn.init.uniform_( - self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale - ) - - def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: - """Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. - """ - (batch_size, num_channels, seq_len) = x.shape - - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.left_pad - - if chunk_size < 0 or chunk_size > seq_len: - chunk_size = seq_len - right_pad = -seq_len % chunk_size - - x = torch.nn.functional.pad(x, (left_pad, right_pad)) - - x_causal = self.causal_conv(x[..., : left_pad + seq_len]) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( - batch_size * num_chunks, num_channels, chunk_size - ) - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size) - - x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape( - batch_size, num_chunks, num_channels, chunk_size - ).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ - ..., :seq_len - ] - - return x_chunk + x_causal - - def _get_chunk_scale(self, chunk_size: int): - """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" - left_edge = self.chunkwise_conv_scale[0] - right_edge = self.chunkwise_conv_scale[1] - if chunk_size < self.kernel_size: - left_edge = left_edge[:, :chunk_size] - right_edge = right_edge[:, -chunk_size:] - else: - t = chunk_size - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros( - channels, t, device=left_edge.device, dtype=left_edge.dtype - ) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) - return 1.0 + (left_edge + right_edge) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Streaming Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - cache: cached left context of shape (batch_size, channels, left_pad) - """ - (batch_size, num_channels, seq_len) = x.shape - - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.left_pad - - # Pad cache - assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[..., -left_pad:] - - x_causal = self.causal_conv(x) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size=seq_len) - x_chunk = x_chunk * chunk_scale - - return x_chunk + x_causal, cache - - class ScaleLimiterFunction(torch.autograd.Function): @staticmethod @@ -1830,7 +698,7 @@ class ScaleLimiter(torch.nn.Module): Assumes channel dim is -1 and the input shape has >1 dimension. """ - def __init__(self, max_rms: FloatLike): + def __init__(self, max_rms: float): super().__init__() self.name = None self.max_rms = max_rms @@ -1924,7 +792,7 @@ class CorrelationLimiter(torch.nn.Module): Assumes input is (batch, seq, channel) """ - def __init__(self, limit: FloatLike = 0.03): + def __init__(self, limit: float = 0.03): super().__init__() self.name = None self.limit = limit @@ -2026,114 +894,6 @@ def _whitening_metric(x: Tensor, num_groups: int): return metric -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: - ctx.save_for_backward(x) - ctx.module = module - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - w = ctx.module - - try: - with torch.enable_grad(): - with torch.amp.autocast('cuda', enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, w.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" - ) - - if metric < float(w.whitening_limit): - w.prob = w.min_prob - return x_grad, None - else: - w.prob = w.max_prob - metric.backward() - penalty_grad = x_detached.grad - scale = float(w.grad_scale) * ( - x_grad.to(torch.float32).norm() - / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None - except Exception as e: - logging.info( - f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." - ) - return x_grad, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float, float]], - grad_scale: FloatLike, - ): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert float(whitening_limit) >= 1 - assert float(grad_scale) >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - self.grad_scale = grad_scale - - if isinstance(prob, float): - prob = (prob, prob) - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob <= self.max_prob <= 1 - self.prob = self.max_prob - self.name = None # will be set in training loop - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: - return _no_op(x) - else: - return WhiteningPenaltyFunction.apply(x, self) class WithLoss(torch.autograd.Function): @@ -2242,56 +1002,6 @@ def forward(self, x): -# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. -class Dropout2(nn.Module): - def __init__(self, p: FloatLike): - super().__init__() - self.p = p - - def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) - - -class MulForDropout3(torch.autograd.Function): - # returns (x * y * alpha) where alpha is a float and y doesn't require - # grad and is zero-or-one. - @staticmethod - @custom_fwd - def forward(ctx, x, y, alpha): - assert not y.requires_grad - ans = x * y * alpha - ctx.save_for_backward(ans) - ctx.alpha = alpha - return ans - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - (ans,) = ctx.saved_tensors - x_grad = ctx.alpha * ans_grad * (ans != 0) - return x_grad, None, None - - -# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, -# and it lets you choose one dimension to share the dropout mask over -class Dropout3(nn.Module): - def __init__(self, p: FloatLike, shared_dim: int): - super().__init__() - self.p = p - self.shared_dim = shared_dim - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - scale = 1.0 / (1 - p) - rand_shape = list(x.shape) - rand_shape[self.shared_dim] = 1 - mask = torch.rand(*rand_shape, device=x.device) > p - ans = MulForDropout3.apply(x, mask, scale) - return ans - - def torch_compile(fn, *args, **kwargs): @@ -2348,26 +1058,8 @@ def forward(self, x: Tensor) -> Tensor: return self.func(x) -class SquareLogSoftmax(nn.Module): - def __init__(self, dim: int = -1, eps: float = 1.0e-03): - super().__init__() - self.dim = dim - self.eps = eps - - def forward(self, x: Tensor): - dim = self.dim - eps = self.eps - with torch.amp.autocast('cuda', enabled=False): - x = x.to(torch.float) - channels = x.shape[dim] - x_sq = x ** 2 - x = (x_sq + eps/channels) / (x_sq.sum(dim=dim, keepdim=True) + eps) - return x.log() - - - -class ActivationDropoutAndLinearFunction(torch.autograd.Function): +class ActivationAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd def forward( @@ -2377,27 +1069,12 @@ def forward( bias: Optional[Tensor], forward_func: Any, backward_func: Any, - dropout_p: float, - dropout_shared_dim: Optional[int], ): - if dropout_p != 0.0: - dropout_shape = list(x.shape) - if dropout_shared_dim is not None: - dropout_shape[dropout_shared_dim] = 1 - # else it won't be very memory efficient. - dropout_mask = (1.0 / (1.0 - dropout_p)) * ( - torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p - ) - else: - dropout_mask = None - - ctx.save_for_backward(x, weight, bias, dropout_mask) + ctx.save_for_backward(x, weight, bias) ctx.backward_func = backward_func x = forward_func(x) - if dropout_mask is not None: - x = x * dropout_mask x = torch.nn.functional.linear(x, weight, bias) return x @@ -2405,11 +1082,9 @@ def forward( @custom_bwd def backward(ctx, ans_grad: Tensor): saved = ctx.saved_tensors - (x, weight, bias, dropout_mask) = saved + (x, weight, bias) = saved y, func_deriv = ctx.backward_func(x) - if dropout_mask is not None: - y = y * dropout_mask # now compute derivative of y w.r.t. weight and bias.. # y: (..., in_channels), ans_grad: (..., out_channels), (out_channels, in_channels) = weight.shape @@ -2420,38 +1095,25 @@ def backward(ctx, ans_grad: Tensor): y_deriv = torch.matmul(ans_grad, weight) bias_deriv = None if bias is None else g.sum(dim=0) x_deriv = y_deriv * func_deriv - if dropout_mask is not None: - # order versus func_deriv does not matter - x_deriv = x_deriv * dropout_mask - - return x_deriv, weight_deriv, bias_deriv, None, None, None, None + return x_deriv, weight_deriv, bias_deriv, None, None -class ActivationDropoutAndLinear(torch.nn.Module): +class ActivationAndLinear(torch.nn.Module): """ - This merges an activation function followed by dropout and then a nn.Linear module; + This merges an activation function followed by a nn.Linear module; it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwashL and dropout_shared_dim != None, this will be + module. If activation == SwashL, this will be equivalent to: nn.Sequential(SwashL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), ScaledLinear(in_channels, out_channels, bias=bias, initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. Args: in_channels: number of input channels, e.g. 256 out_channels: number of output channels, e.g. 256 bias: if true, have a bias activation: the activation function, for now just support SwashL, SwashR. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). """ def __init__( self, @@ -2459,8 +1121,6 @@ def __init__( out_channels: int, bias: bool = True, activation: str = "SwashL", - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, initial_scale: float = 1.0, ): super().__init__() @@ -2478,8 +1138,6 @@ def __init__( self.register_parameter("bias", l.bias) self.activation = activation - self.dropout_p = dropout_p - self.dropout_shared_dim = dropout_shared_dim assert activation in ["SwashL", "SwashR"] if activation == "SwashL": @@ -2495,14 +1153,12 @@ def forward(self, x: Tensor): x = self.forward_func(x) return torch.nn.functional.linear(x, self.weight, self.bias) - return ActivationDropoutAndLinearFunction.apply( + return ActivationAndLinearFunction.apply( x, self.weight, self.bias, self.forward_func, self.backward_func, - float(self.dropout_p), - self.dropout_shared_dim, ) @@ -2516,45 +1172,6 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: return torch.cat((x, zeros), dim=-1) -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = (1.2 - (-0.043637)) / 255.0 - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - def _test_swashl_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 @@ -2596,64 +1213,30 @@ def _test_softmax(): assert torch.allclose(a.grad, b.grad) -def _test_piecewise_linear(): - p = PiecewiseLinear((0, 10.0)) - for x in [-100, 0, 100]: - assert p(x) == 10.0 - p = PiecewiseLinear((0, 10.0), (1, 0.0)) - for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: - print("x, y = ", x, y) - assert p(x) == y, (x, p(x), y) - - q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] - pq = p.max(q) - for x in x_vals: - y1 = max(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p.min(q) - for x in x_vals: - y1 = min(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p + q - for x in x_vals: - y1 = p(x) + q(x) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - - -def _test_activation_dropout_and_linear(): +def _test_activation_and_linear(): in_channels = 20 out_channels = 30 for bias in [True, False]: - # actually we don't test for dropout_p != 0.0 because forward functions will give - # different answers. This is because we are using the k2 implementation of - # swash_l an swash_r inside SwashL() and SwashR(), and they call randn() - # internally, messing up the random state. - for dropout_p in [0.0]: + if True: for activation in ["SwashL", "SwashR"]: m1 = nn.Sequential( SwashL() if activation == "SwashL" else SwashR(), - Dropout3(p=dropout_p, shared_dim=-1), ScaledLinear( in_channels, out_channels, bias=bias, initial_scale=0.5 ), ) - m2 = ActivationDropoutAndLinear( + m2 = ActivationAndLinear( in_channels, out_channels, bias=bias, initial_scale=0.5, activation=activation, - dropout_p=dropout_p, ) with torch.no_grad(): - m2.weight[:] = m1[2].weight + m2.weight[:] = m1[1].weight if bias: - m2.bias[:] = m1[2].bias + m2.bias[:] = m1[1].bias # make sure forward gives same result. x1 = torch.randn(10, in_channels) x1.requires_grad = True @@ -2671,17 +1254,17 @@ def _test_activation_dropout_and_linear(): y2.backward(gradient=y_grad) print( - f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + f"bias = {bias}, activation = {activation}" ) print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) - print("grad1 = ", m1[2].weight.grad) + print("grad1 = ", m1[1].weight.grad) print("grad2 = ", m2.weight.grad) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) + assert torch.allclose(m1[1].weight.grad, m2.weight.grad, atol=1.0e-05) if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) + assert torch.allclose(m1[1].bias.grad, m2.bias.grad, atol=1.0e-05) print("x1.grad = ", x1.grad) print("x2.grad = ", x2.grad) @@ -2695,24 +1278,18 @@ def isclose(a, b): # storage of it. assert isclose(x1.grad, x2.grad) + def _test_orthogonal_linear(): m = OrthogonalLinear(128, 128) m(torch.randn(30, 2, 128)) -def _test_simple_orthogonal_linear(): - m = SimpleOrthogonalLinear(128, 128) - m(torch.randn(30, 2, 128)) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) - _test_piecewise_linear() _test_softmax() - _test_whiten() _test_swashr_deriv() _test_swashl_deriv() - _test_activation_dropout_and_linear() + _test_activation_and_linear() _test_orthogonal_linear() - _test_simple_orthogonal_linear() diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 49ab427764..cc8a52e900 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -24,14 +24,10 @@ from scaling import ( ScaleLimiter, ScaledLinear, - FloatLike, - get_max_similarity, ScaledConv2d, ScaleGrad, - ScheduledFloat, SwashL, SwashR, - CosineSimilarityLoss, with_loss, ) from torch import Tensor, nn @@ -72,7 +68,7 @@ def __init__( padding = (kernel_size[0] // 2, kernel_size[1] // 2) else: padding = (0, kernel_size[1] // 2) - self.left_pad = kernel_size[0] - 1 + self.left_pad = kernel_size[0] - 1 self.depthwise_conv = nn.Conv2d( in_channels=channels, @@ -116,7 +112,7 @@ def forward( x = bypass + x return x - + def streaming_forward( self, x: Tensor, @@ -339,58 +335,58 @@ def get_init_cache( def _test_conv2d_subsampling_streaming(): logging.info("Testing Conv2dSubsampling streaming equivalence...") - + batch_size = 2 idim = 80 odim = 256 - + model = Conv2dSubsampling( in_channels=idim, out_channels=odim, causal=True ) - + model.eval() out_chunk_size = 32 - in_chunk_size = out_chunk_size * 2 + 7 - in_shift = out_chunk_size * 2 - + in_chunk_size = out_chunk_size * 2 + 7 + in_shift = out_chunk_size * 2 + num_chunks = 10 - + seq_len = num_chunks * in_shift + 7 - + x_full = torch.randn(batch_size, seq_len, idim) x_lens_full = torch.full((batch_size,), seq_len, dtype=torch.int64) - + with torch.no_grad(): out_full, out_lens_full = model(x_full, x_lens_full) - + cache = model.get_init_cache(batch_size=batch_size) - + out_chunks = [] out_offset = 0 - + for i in range(num_chunks): start = i * in_shift end = start + in_chunk_size x_chunk = x_full[:, start:end, :] x_lens_chunk = torch.full((batch_size,), in_chunk_size, dtype=torch.int64) - + out_chunk, out_lens_chunk, cache = model.streaming_forward( x_chunk, x_lens_chunk, cache ) out_chunks.append(out_chunk) - + out_chunk_len = out_chunk.shape[1] expected_out = out_full[:, out_offset : out_offset + out_chunk_len, :] - + diff_chunk = torch.max(torch.abs(expected_out - out_chunk)) logging.info(f"Chunk {i+1} | Input: {x_chunk.shape} -> Output: {out_chunk.shape} | Max diff: {diff_chunk}") - + assert torch.allclose(expected_out, out_chunk, atol=1e-4), f"Chunk {i+1} mismatch! max diff: {diff_chunk}" out_offset += out_chunk_len - + out_stream_cat = torch.cat(out_chunks, dim=1) diff_total = torch.max(torch.abs(out_full - out_stream_cat)) logging.info(f"Total Max Diff between full forward and streaming: {diff_total}") @@ -403,4 +399,4 @@ def _test_conv2d_subsampling_streaming(): logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) - _test_conv2d_subsampling_streaming() \ No newline at end of file + _test_conv2d_subsampling_streaming() diff --git a/egs/librispeech/ASR/zipformer/test_subsampling.py b/egs/librispeech/ASR/zipformer/test_subsampling.py index 078227fb68..b502d5a773 100755 --- a/egs/librispeech/ASR/zipformer/test_subsampling.py +++ b/egs/librispeech/ASR/zipformer/test_subsampling.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 import torch -from scaling import ScheduledFloat from subsampling import Conv2dSubsampling - +# TODO: fix, this does not work right tnow def test_conv2d_subsampling(): layer1_channels = 8 layer2_channels = 32 @@ -17,7 +16,6 @@ def test_conv2d_subsampling(): layer1_channels=layer1_channels, layer2_channels=layer2_channels, layer3_channels=layer3_channels, - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), ) N = 2 T = 200 diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7805266f2e..dbb69a4ac4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -26,18 +26,14 @@ import torch from encoder_interface import EncoderInterface from scaling import ( - ActivationDropoutAndLinear, - ChunkCausalDepthwiseConv1d, - CosineSimilarityLoss, + ActivationAndLinear, CorrelationLimiter, Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. OrthogonalLinear, RmsNorm, SequenceNorm, - SimpleOrthogonalLinear, + OrthogonalLinear, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - ScheduledFloat, - FloatLike, SwashR, convert_num_channels, limit_param_value, @@ -720,8 +716,8 @@ def __init__( super().__init__() # self.downsample will also reverse the downsampling operation for us afterward. - self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, - lr_scale=0.66, bias=False) + self.proj = OrthogonalLinear(dim, encoder_layer.embed_dim, + lr_scale=0.66, bias=False) self.name = None self.layers = nn.ModuleList( @@ -895,7 +891,7 @@ class ResidualModule(nn.Module): def __init__( self, embed_dim: int, - function_scale_min: FloatLike = 0.1, + function_scale_min: float = 0.1, ): super().__init__() self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) @@ -1052,13 +1048,11 @@ def __init__( query_head_dim: int, pos_head_dim: int = 4, value_head_dim: int = 12, - dropout: float = 0.0, ) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.query_head_dim = query_head_dim - self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. self.in_norm = RmsNorm() @@ -1188,10 +1182,6 @@ def forward( elif random.random() < 0.001: self._print_attn_entropy(attn_weights) - # note: self.dropout is normally 0.0. - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) v = v.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) @@ -1393,11 +1383,9 @@ def __init__(self, embed_dim: int, feedforward_dim: int): # to the TransformedAdam optimizer. self.in_proj.weight_min_rms = 0.02 - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( + self.out_proj = ActivationAndLinear( feedforward_dim, embed_dim, - dropout_p=0.0, activation="SwashR", initial_scale=0.5, bias=True, @@ -1691,11 +1679,10 @@ def __init__( self.depthwise_conv.lr_scale = 0.66 - self.out_proj = ActivationDropoutAndLinear( + self.out_proj = ActivationAndLinear( bottleneck_dim, channels, activation="SwashR", - dropout_p=0.0, initial_scale=0.05, ) From 075b702760f82e4339855440dc27948b4c11947a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 14 Mar 2026 12:02:56 +0800 Subject: [PATCH 0970/1191] Increase CorrelationLimiter limit from .35 to .45 --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7805266f2e..a7f1f60ec0 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -540,7 +540,7 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) - power = 0.35 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn = MultiheadRelPosGatedSelfAttention( From f8db8375813016297b418bfd69a6988ac505e706 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 14 Mar 2026 23:46:01 +0800 Subject: [PATCH 0971/1191] Replace CosineLRScheduler with HalfCosineLRScheduler --- .../ASR/zapformer/combined_scheduler.py | 18 ++++++++++++++++++ egs/librispeech/ASR/zapformer/train.py | 10 +++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 6a91a312cc..36662ac7e2 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -178,6 +178,24 @@ def get_lr(self): return [x * factor for x in self.base_lrs] +class HalfCosineLRScheduler(CombinedLRScheduler): + def __init__(self, + *args, + **kwargs): + """ + Cosine learning rate scheduler consisting of cosine from 0 to pi/2 with no offset, + that inherits from CombinedLRScheduler (see its documentation + to understand general aspects of usage). Equivalent to sqrt of normal cosine + LR schedule. + """ + super().__init__(*args, **kwargs) + + def get_lr(self): + progress = self.get_progress() + factor = math.cos((math.pi / 2) * progress) + return [x * factor for x in self.base_lrs] + + class LinearLRScheduler(CombinedLRScheduler): def __init__(self, *args, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6808cf8a87..e77a6874bf 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -76,9 +76,9 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import TransformedAdam -from combined_scheduler import CombinedLRScheduler, CosineLRScheduler +from combined_scheduler import CombinedLRScheduler try: - from combined_scheduler import CosineLRScheduler + from combined_scheduler import HalfCosineLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR @@ -1377,9 +1377,9 @@ def lr_lambda(current_step): progress = current_step / total_steps return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - scheduler = CosineLRScheduler(optimizer, - batches_per_epoch=params.batches_per_epoch, - num_epochs=params.num_epochs) + scheduler = HalfCosineLRScheduler(optimizer, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From f3fd4d88a451ac59d96729d16346551fd0aa595d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 16 Mar 2026 11:33:09 +0800 Subject: [PATCH 0972/1191] Implement InterpCosineLRScheduler --- egs/librispeech/ASR/zapformer/combined_scheduler.py | 13 ++++++++----- egs/librispeech/ASR/zapformer/train.py | 8 ++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 36662ac7e2..cd6a2822ad 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -178,21 +178,24 @@ def get_lr(self): return [x * factor for x in self.base_lrs] -class HalfCosineLRScheduler(CombinedLRScheduler): +class InterpCosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, **kwargs): """ - Cosine learning rate scheduler consisting of cosine from 0 to pi/2 with no offset, - that inherits from CombinedLRScheduler (see its documentation - to understand general aspects of usage). Equivalent to sqrt of normal cosine - LR schedule. + This cosine LR scheduler is halfway between the conventional cosine LR scheduler + that takes the cosine from 0 to pi, and one that takes the cosine from 0 to pi/2. + It inherits from CombinedLRScheduler (see its documentation + to understand general aspects of usage). """ super().__init__(*args, **kwargs) def get_lr(self): progress = self.get_progress() factor = math.cos((math.pi / 2) * progress) + # factor**2 would be the conventional cosine LR scheduler with cosine from 0 to pi, we interpolate + # between the two. + factor = 0.5 * (factor + factor ** 2) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index e77a6874bf..2cfff762f5 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -78,7 +78,7 @@ from optim import TransformedAdam from combined_scheduler import CombinedLRScheduler try: - from combined_scheduler import HalfCosineLRScheduler + from combined_scheduler import InterpCosineLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR @@ -1377,9 +1377,9 @@ def lr_lambda(current_step): progress = current_step / total_steps return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - scheduler = HalfCosineLRScheduler(optimizer, - batches_per_epoch=params.batches_per_epoch, - num_epochs=params.num_epochs) + scheduler = InterpCosineLRScheduler(optimizer, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 3cd164524d541403aa1174d4a2c0278cf66b8469 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 16 Mar 2026 11:54:35 +0800 Subject: [PATCH 0973/1191] Some configuration changes; CorrelationLimiter power 0.45->0.4, cubic_decay_proportion=0.85 to cubic_decay_proportion=0.8, beta1=0.998 to beta1=0.995. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 2cfff762f5..d3640dd037 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1364,9 +1364,9 @@ def run(rank, world_size, args): get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, direct=0.15, - cubic_decay_proportion=0.85, + cubic_decay_proportion=0.8, wd=18, - beta1=0.998, + beta1=0.995, scale_limits=(1.0, 4.0), ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index a7f1f60ec0..f1b05f3b0c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -540,7 +540,7 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) - power = 0.45 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + power = 0.4 # power should be between 0 and 1. 1 would mean cov == I (unattainable) self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn = MultiheadRelPosGatedSelfAttention( From 56d283dfe99d90f7642a1d468ecaa14d2fedbdf2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 16 Mar 2026 15:51:30 +0800 Subject: [PATCH 0974/1191] Change where padding is done in ConvolutionModule and round up sequence length a little --- egs/librispeech/ASR/zipformer/zipformer.py | 42 ++++++++++++++++------ 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index f1b05f3b0c..fd92163651 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1595,6 +1595,17 @@ def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: return pos_weights +def round_up_to_power_of_two(x): + x = x - 1 + x = x | x >> 1 + x = x | x >> 2 + x = x | x >> 4 + x = x | x >> 8 + x = x | x >> 16 + x = x + 1 + return x + + class FftConv(nn.Module): def __init__(self, num_channels: int, @@ -1612,10 +1623,18 @@ def forward(self, x: Tensor) -> Tensor: (seq_len, batch_size, num_channels) = x.shape + # select a power of two that's >= seq_len // 8 and round up seq_len + # to a multiple of that power. This means that rounded_seq_len + # will be of the form (2**n) * k where k <= 8, so it won't contain + # many factors other than two; this will make the FFT more efficient + # without adding an excessive amount of padding. + power_of_two = max(1, round_up_to_power_of_two(seq_len // 8)) + rounded_seq_len = power_of_two * ((seq_len + power_of_two - 1) // power_of_two) + with torch.amp.autocast('cuda', enabled=False): # do it in float32 because non power of two seq_len is not supported in half precision. - x = torch.fft.rfft(x.to(torch.float32), dim=0) + x = torch.fft.rfft(x.to(torch.float32), dim=0, n=rounded_seq_len) # x: (num_freqs, batch_size, num_channels) N = x.shape[0] # num freqs weight = 4. * self.weight @@ -1628,7 +1647,9 @@ def forward(self, # weight: (N, num_channels) weight = weight.unsqueeze(1) # (N, 1, num_channels) x = x * weight - x = torch.fft.irfft(x, n=seq_len, dim=0) + x = torch.fft.irfft(x, n=rounded_seq_len, dim=0) + + x = x[:seq_len] try: x = x + self.bias @@ -1716,12 +1737,6 @@ def forward( Returns: Tensor: Output tensor (#time, batch, channels). """ - # x: (time, batch, channels) - # Caution: this module is not completely - # invariant to the number of frames each sequence is padded with, since - # the FFT-based convolution treats the signal as repeating. - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) @@ -1732,13 +1747,18 @@ def forward( x = x * s x = self.activation2(x) # identity + + # x: (time, batch, channels) + # Caution: this module is not completely + # invariant to the number of frames each sequence is padded with, since + # the FFT-based convolution treats the signal as repeating. + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) + if self.causal: # Not support exporting a model for simulated streaming decoding assert not torch.jit.is_scripting() and not torch.jit.is_tracing() x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) - # for the causal version, we don't use fft-conv - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x_shape = x.shape x = torch.nn.functional.pad(x, (self.left_pad, 0)) x = self.depthwise_conv(x) From 8b1ed2efdb5d79004570d1b883f5a988ddd5eb05 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Mar 2026 18:16:50 +0800 Subject: [PATCH 0975/1191] Simplify the interface of model.py, moving SpecAug augmentation out into the main training loop --- .../ASR/zapformer/asr_datamodule.py | 8 +- egs/librispeech/ASR/zapformer/model.py | 70 +----------- .../ASR/zapformer/multicopy_dataset.py | 4 +- egs/librispeech/ASR/zapformer/train.py | 105 ++++++++++++------ icefall/utils.py | 21 +++- 5 files changed, 93 insertions(+), 115 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 71585227c4..853c14c7c7 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -33,11 +33,11 @@ SimpleCutSampler, ) # MulticopyDataset is a modified version of K2SpeechRecognitionDataset from -# lhotse.dataset, modified to, in training mode, to return a batch that has 3 -# different copies of the same data with the last two having different Musan +# lhotse.dataset, modified to, in training mode, to return a batch that has 2 +# different copies of the same data having different Musan # augmentations and the first having none; and also include the key "num_copies" -# in the batch which would be 1 for the validation data (no Musan) and 3 for the -# training data with musan. +# in the batch which would be 1 for the validation data (no Musan) and 2 for the +# different copies of the training data with musan. try: from multicopy_dataset import MulticopyDataset # interface like K2SpeechRecognitionDataset except: diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index cc96669bc4..b807876447 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -24,7 +24,7 @@ from torch import Tensor from encoder_interface import EncoderInterface from scaling import ScaledLinear, convert_num_channels, PredictLoss -from icefall.utils import add_sos, make_pad_mask, time_warp +from icefall.utils import add_sos, make_pad_mask class AsrModel(nn.Module): @@ -362,10 +362,6 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - spec_augment: Optional[nn.Module] = None, - supervision_segments: Optional[torch.Tensor] = None, - time_warp_factor: Optional[int] = 80, - num_copies: int = 1, aux_loss_scale: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -387,20 +383,6 @@ def forward( lm_scale: The scale to smooth the loss with lm (output of predictor network) part - spec_augment: - The SpecAugment instance, or similar/compatible object, that masks - log-mel features. - supervision_segments: - An int tensor of shape ``(S, 3)``. ``S`` is the number of - supervision segments that exist in ``features``. Used only for - time-warping, if num_copies > 1. - time_warp_factor: - Parameter for the time warping; larger values mean more warping. - Set to ``None``, or less than ``1``, to disable. - Used only if num_copies > 1, corresponds to training mode. - num_copies: - the number of copies of the same data that are in the batch, e.g. 1, 2 - or 3; affects CRCTC, spec-augment, etc. aux_loss_scale: auxiliary-loss scale, for scaling cosine losses in the encoders. sc_prob: @@ -426,56 +408,6 @@ def forward( device = x.device - if num_copies > 1: - assert num_copies == 3 # for now. - # will do SpecAugment or similar. - assert spec_augment is not None and getattr(spec_augment, 'time_warp_factor', -1) < 0 - - (batch_size, seq_len, num_channels) = x.shape - B = batch_size // num_copies - x = x.reshape(num_copies, B, seq_len, num_channels) - - do_time_warp = True - if do_time_warp: - shared_time_warp = False - if shared_time_warp: - # Apply time warping. First append the copies on the channel - # dimension so all copies get the exact same time-warping. - x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) - else: - x = x.reshape(num_copies * B, seq_len, num_channels) - - assert supervision_segments is not None - with torch.amp.autocast('cuda', enabled=False): - x = time_warp( - x.to(torch.float), - time_warp_factor=time_warp_factor, - supervision_segments=supervision_segments[:x.shape[0]], - ) - if shared_time_warp: - x = x.reshape(B, seq_len, num_copies, num_channels) - x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) - else: - x = x.reshape(num_copies, B, seq_len, num_channels) - - - # x_no_specaug is several repeats of the 1st copy of the data, which - # is the one not augmented with Musan. But it does have time - # warping and mel warping. - x_no_specaug = x[0:1].repeat(num_copies - 1, 1, 1, 1).reshape( - B * (num_copies - 1), seq_len, num_channels) - - - # Independently apply frequency masking and time masking to all but the first - # copy of the data. - x = spec_augment(x[1:].reshape(-1, seq_len, num_channels)) - - x_lens = x_lens[:B*(num_copies-1)] - y = y[:B*(num_copies-1)] - else: - x_no_specaug = x - - # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens, aux_loss_scale=aux_loss_scale) diff --git a/egs/librispeech/ASR/zapformer/multicopy_dataset.py b/egs/librispeech/ASR/zapformer/multicopy_dataset.py index f27360ad0b..ffac8b04af 100755 --- a/egs/librispeech/ASR/zapformer/multicopy_dataset.py +++ b/egs/librispeech/ASR/zapformer/multicopy_dataset.py @@ -125,8 +125,8 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]] for tnfm in self.cut_transforms: cuts = tnfm(cuts) - cuts = orig_cuts + cuts - num_copies = 3 + #cuts = orig_cuts + cuts + num_copies = 2 else: num_copies = 1 diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d3640dd037..640bd147cb 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -74,7 +74,7 @@ from lhotse.cut import Cut, CutSet from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model import AsrModel +from model2 import AsrModel from optim import TransformedAdam from combined_scheduler import CombinedLRScheduler try: @@ -108,6 +108,7 @@ get_parameter_groups_with_lrs, setup_logger, str2bool, + time_warp, ) @@ -891,13 +892,62 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) + +def augmentation( + features: Tensor, + feature_lens: Tensor) -> Tensor: + """ + Does augmentation; if need_unaugmented_features returns (augmented_features, unaugmented_features), + else (augmented_features, None) + + Args: + params: command-lines options + num_copies: the number of copies of the data in "feature", expected to be 3, consisting of + (noise_augmentation_copy1, noise_augmentation_copy2, no_noise_augmentation). + features: a Tensor of shape (batch_size, seq_len, num_channels), with batch_size + expected to be a multiple of num_copies, with 3 versions of the minibatch appended + with torch.cat((aug1, aug2, original), dim=0) + + Returns: + (augmented_features, unaugmented_features). + + augmented_features: feature with SpecAug, of shape (2 * batch_size // 3, seq_len, num_channels) + unaugmented_features: if need_unaugmented_features, of shape (2 * batch_size // 3, seq_len, num_channels); + else, None. Note: these features will actually include any time-warping, based on the assumption + that this needs to be kept in sync. + """ + assert num_copies in [1, 3] + (batch_size, seq_len, num_channels) = x.shape + B = batch_size // num_copies + x = x.reshape(num_copies, B, seq_len, num_channels) + + do_time_warp = True + + if do_time_warp: + with torch.amp.autocast('cuda', enabled=False): + x = time_warp( + x.to(torch.float), + time_warp_factor=80, + feature_lens=feature_lens, + ) + + # note: ExpAugment() does *somewhat* assume that x consists of two copies of + # the same data, but practically speaking the only important use this is put + # to is that it chooses non-overlapping frequency regions to mask. it also + # chooses non-overlapping time regions to mask, but this is not so important + # since the time warping (if used) was done independently on the two copies. + spec_augment = ExpAugment() + x = spec_augment(x) + + return x + + def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - spec_augment: Optional[nn.Module] = None, aux_loss_scale: float = 0.0, ) -> Tuple[Tensor, MetricsTracker]: """ @@ -915,14 +965,12 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - spec_augment: - The nn.Module instance (or similar object), used for training """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) + features = batch["inputs"] + # at entry, features is (N, T, C) + assert features.ndim == 3 + features = features.to(device) supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) @@ -934,34 +982,27 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) - if num_copies > 1: - assert model.training - # will need the following for time-warping in nn.Module. - supervision_intervals = batch["supervisions"] - supervision_segments = torch.stack( - [ - supervision_intervals["sequence_idx"], - supervision_intervals["start_frame"], - supervision_intervals["num_frames"], - ], - dim=1, - ) # shape: (S, 3) + + if is_training: + # the num_copies thing is actually not very important any more, you can remove + # the assertion if it's a problem in future. (previously we used losses that + # required the different copies to be in sync on the time dimension, e.g. + # to use the same time warping; we don't do this any more.) + assert num_copies == 2 + batch_size = features.shape[0] + features = augmentation(features, feature_lens) else: - supervision_segments = None - spec_augment = None # disable spec-aug + assert num_copies == 1 + with torch.set_grad_enabled(is_training): simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( - x=feature, + x=features, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, - spec_augment=spec_augment, - supervision_segments=supervision_segments, - time_warp_factor=80, # for specaug - num_copies=num_copies, aux_loss_scale=aux_loss_scale, ) @@ -1056,7 +1097,6 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, - spec_augment: Optional[nn.Module] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1083,8 +1123,6 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. - spec_augment: - The SpecAugment or similar instance used for CR-CTC. model_avg: The stored model averaged from the start of training. tb_writer: @@ -1137,7 +1175,6 @@ def save_bad_model(suffix: str = ""): sp=sp, batch=batch, is_training=True, - spec_augment=spec_augment, aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), ) # summary stats @@ -1342,8 +1379,6 @@ def run(rank, world_size, args): assert params.use_ctc # for now, require CTC, we may remove this requirement later. - spec_augment = ExpAugment() - assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1499,7 +1534,6 @@ def remove_short_and_long_utt(c: Cut): optimizer=optimizer, sp=sp, params=params, - spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) @@ -1527,7 +1561,6 @@ def remove_short_and_long_utt(c: Cut): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, - spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1596,7 +1629,6 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, - spec_augment: Optional[nn.Module] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1616,7 +1648,6 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/icefall/utils.py b/icefall/utils.py index e523a2e546..9997cdc9cb 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -2315,15 +2315,19 @@ def time_warp( p: float = 0.9, time_warp_factor: Optional[int] = 80, supervision_segments: Optional[torch.Tensor] = None, + feature_lens: Optional[torch.Tensor] = None, ): - """Apply time warping on a batch of features""" + """Apply time warping on a batch of features + supervision_segments and feature_lens are two alternative ways of specifying the parts of the feature matrix to + warp, see the code for details. + """ if time_warp_factor is None or time_warp_factor < 1: return features assert ( len(features.shape) == 3 ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}" features = features.clone() - if supervision_segments is None: + if supervision_segments is None and feature_lens is None: # No supervisions - apply spec augment to full feature matrices. for sequence_idx in range(features.size(0)): if random.random() > p: @@ -2332,7 +2336,8 @@ def time_warp( features[sequence_idx] = time_warp_impl( features[sequence_idx], factor=time_warp_factor ) - else: + elif supervision_segments is not None: + assert feature_lens is None # Supervisions provided - we will apply time warping only on the supervised areas. for sequence_idx, start_frame, num_frames in supervision_segments: if random.random() > p: @@ -2343,4 +2348,14 @@ def time_warp( features[sequence_idx, start_frame:end_frame], factor=time_warp_factor ) + else: + for sequence_idx, num_frames in enumerate(feature_lens): + if random.random() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx, :num_frames] = time_warp_impl( + features[sequence_idx, :num_frames], factor=time_warp_factor + ) + + return features From 39a66470be767849356cd6f9c66555fc270a1c1b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Mar 2026 20:30:17 +0800 Subject: [PATCH 0976/1191] Bug fix --- egs/librispeech/ASR/zapformer/train.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 640bd147cb..219260ddf7 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -916,17 +916,14 @@ def augmentation( else, None. Note: these features will actually include any time-warping, based on the assumption that this needs to be kept in sync. """ - assert num_copies in [1, 3] - (batch_size, seq_len, num_channels) = x.shape - B = batch_size // num_copies - x = x.reshape(num_copies, B, seq_len, num_channels) + (batch_size, seq_len, num_channels) = features.shape do_time_warp = True if do_time_warp: with torch.amp.autocast('cuda', enabled=False): - x = time_warp( - x.to(torch.float), + features = time_warp( + features.to(torch.float), time_warp_factor=80, feature_lens=feature_lens, ) @@ -937,9 +934,9 @@ def augmentation( # chooses non-overlapping time regions to mask, but this is not so important # since the time warping (if used) was done independently on the two copies. spec_augment = ExpAugment() - x = spec_augment(x) + features = spec_augment(features) - return x + return features def compute_loss( From 315471b8ea99c3374e9799795a05b7b990d242d3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Mar 2026 20:54:48 +0800 Subject: [PATCH 0977/1191] fix import --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 219260ddf7..46953b7cd2 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -74,7 +74,7 @@ from lhotse.cut import Cut, CutSet from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model2 import AsrModel +from model import AsrModel from optim import TransformedAdam from combined_scheduler import CombinedLRScheduler try: From ba0d20bae48e865b031140e2627082581da069f2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 17 Mar 2026 21:18:58 +0800 Subject: [PATCH 0978/1191] Bug fixes --- egs/librispeech/ASR/zapformer/model.py | 66 +++----------------------- egs/librispeech/ASR/zapformer/train.py | 32 ++----------- 2 files changed, 10 insertions(+), 88 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index b807876447..21fce88351 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -204,48 +204,6 @@ def forward_ctc( ) return ctc_loss - def forward_cr_ctc( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - targets: torch.Tensor, - target_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute CTC loss, with consistency regularization loss if we are in training mode. - Args: - encoder_out: - Encoder output, of shape (2 * N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (2 * N,). - targets: - Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed - to be un-padded and concatenated within 1 dimension. - """ - # Compute CTC loss - ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) - ctc_loss = torch.nn.functional.ctc_loss( - log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) - targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. - input_lengths=encoder_out_lens.long(), - target_lengths=target_lengths.long(), - reduction="sum", - ) - - # Compute consistency regularization loss - exchanged_targets = ctc_output.detach().chunk(2, dim=0) - exchanged_targets = torch.cat( - [exchanged_targets[1], exchanged_targets[0]], dim=0 - ) # exchange: [x1, x2] -> [x2, x1] - cr_loss = nn.functional.kl_div( - input=ctc_output, - target=exchanged_targets, - reduction="none", - log_target=True, - ) # (2 * N, T, C) - length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) - cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() - - return ctc_loss, cr_loss def forward_transducer( self, @@ -432,24 +390,14 @@ def forward( if self.use_ctc: targets = y.values - if not self.training: - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) - cr_loss = torch.empty(0) - else: - ctc_loss, cr_loss = self.forward_cr_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) else: ctc_loss = torch.empty(0) - cr_loss = torch.empty(0) if self.use_attention_decoder: attention_decoder_loss = self.attention_decoder.calc_att_loss( @@ -461,4 +409,4 @@ def forward( else: attention_decoder_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 46953b7cd2..d947b49dc3 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -48,7 +48,6 @@ - transducer loss (default) - ctc loss - attention decoder loss - - cr-ctc loss (should use half the max-duration compared to regular ctc) """ @@ -500,13 +499,6 @@ def get_parser(): help="Scale for CTC loss.", ) - parser.add_argument( - "--cr-loss-scale", - type=float, - default=0.2, - help="Scale for consistency-regularization loss.", - ) - parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -897,24 +889,12 @@ def augmentation( features: Tensor, feature_lens: Tensor) -> Tensor: """ - Does augmentation; if need_unaugmented_features returns (augmented_features, unaugmented_features), - else (augmented_features, None) Args: - params: command-lines options - num_copies: the number of copies of the data in "feature", expected to be 3, consisting of - (noise_augmentation_copy1, noise_augmentation_copy2, no_noise_augmentation). - features: a Tensor of shape (batch_size, seq_len, num_channels), with batch_size - expected to be a multiple of num_copies, with 3 versions of the minibatch appended - with torch.cat((aug1, aug2, original), dim=0) + features: a Tensor of shape (batch_size, seq_len, num_channels) Returns: - (augmented_features, unaugmented_features). - - augmented_features: feature with SpecAug, of shape (2 * batch_size // 3, seq_len, num_channels) - unaugmented_features: if need_unaugmented_features, of shape (2 * batch_size // 3, seq_len, num_channels); - else, None. Note: these features will actually include any time-warping, based on the assumption - that this needs to be kept in sync. + augmented_features """ (batch_size, seq_len, num_channels) = features.shape @@ -993,7 +973,7 @@ def compute_loss( with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( x=features, x_lens=feature_lens, y=y, @@ -1020,8 +1000,6 @@ def warmup_schedule(scale, initial_factor): if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss - if num_copies > 1: - loss += params.cr_loss_scale * cr_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -1032,8 +1010,6 @@ def warmup_schedule(scale, initial_factor): with warnings.catch_warnings(): warnings.simplefilter("ignore") nframes = (feature_lens // params.subsampling_factor).sum().item() - if num_copies > 1: - nframes = nframes * (num_copies - 1) / num_copies # omit 1st copy info["frames"] = nframes # Note: We use reduction=sum while computing the loss. @@ -1043,8 +1019,6 @@ def warmup_schedule(scale, initial_factor): info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() - if num_copies > 1: - info["cr_loss"] = cr_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() From d2ad0bb650da389991ae25331f00eb11942cdf19 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 19 Mar 2026 12:12:29 +0800 Subject: [PATCH 0979/1191] Change defaults and test code in optim.py, will not affect our runs. --- egs/librispeech/ASR/zipformer/optim.py | 36 ++++++++++++++++---------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 47d2732255..317dc1eeec 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -495,13 +495,13 @@ def __init__( self, params, lr=1e-03, - beta1=0.999, - direct=0.0, # scale on bypass of momentum (beta1) + beta1=0.995, + direct=0.15, # scale on bypass of momentum (beta1) cubic_decay_proportion=0.8, beta2=0.98, - wd=25, + wd=12, eps=1.0e-16, - scale_limits=(0.5, 2.0), + scale_limits=(1.0, 4.0), ): defaults = dict( @@ -905,13 +905,13 @@ def __init__( self, params, lr=1e-03, - beta1=0.999, - direct=0.0, # scale on bypass of momentum (beta1) + beta1=0.995, + direct=0.15, # scale on bypass of momentum (beta1) cubic_decay_proportion=0.8, beta2=0.98, - wd=25, + wd=12, eps=1.0e-16, - scale_limits=(0.5, 2.0), + scale_limits=(1.0, 4.0), ): defaults = dict( lr=lr, @@ -1012,19 +1012,27 @@ def _test_transformed_adam(hidden_dim: int): for _ in range(20) ] - lr = 0.0006 + lr = 0.001 + # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the + # optimum parameters very exactly. Normally you want something more like the + # defaults of beta1=0.995 and direct=0.15 if test == 0: - optim = TransformedAdam(m.named_parameters(), lr=lr, wd=24) + optim = TransformedAdam(m.named_parameters(), lr=lr, direct=0.0, beta1=0.999) elif test == 1: - optim = SimpleTransformedAdam(m.parameters(), lr=lr, wd=24) + optim = SimpleTransformedAdam(m.parameters(), lr=lr, direct=0.0, beta1=0.999) num_epochs = 180 total_steps = num_epochs def lr_lambda(current_step): - # Cosine annealing - progress = current_step / total_steps - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + # a LR schedule similar to InterpCosineLRScheduler from combined_scheduler.py + progress = min(1, current_step / total_steps) + cos = math.cos(progress * math.pi / 2) + # the relatively small scale on cos means the linear cool-down phase + # is long/slow, as the loss of this easy task is dominated by + # parameter noise.. in practical scenarios we use larger scale on + # the cos term, e.g. as large as 0.66. + return 0.05 * cos + 0.95 * (cos ** 2) scheduler = LambdaLR(optim, lr_lambda) From 2f227b3dc65e4885fe6c2d202992cd1d0fe1cc56 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 19 Mar 2026 13:11:43 +0800 Subject: [PATCH 0980/1191] Move code to batched_rubik, rubik instead of optim.py --- .../ASR/zapformer/batched_rubik.py | 819 ++++++++++++++++++ egs/librispeech/ASR/zapformer/rubik.py | 470 ++++++++++ egs/librispeech/ASR/zapformer/train.py | 13 +- 3 files changed, 1300 insertions(+), 2 deletions(-) create mode 100644 egs/librispeech/ASR/zapformer/batched_rubik.py create mode 100644 egs/librispeech/ASR/zapformer/rubik.py diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py new file mode 100644 index 0000000000..206673dfc7 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -0,0 +1,819 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import math +import logging +import random +from collections import defaultdict +from torch.optim.lr_scheduler import LambdaLR + +from typing import Dict, List, Optional, Tuple, Union +import torch +from torch import Tensor +from torch.optim import Optimizer + + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + + +def compute_prod3(x): + assert x.ndim >= 2 + if x.shape[-2] <= x.shape[-1]: + x2 = torch.matmul(x, x.transpose(-2, -1)) + return torch.matmul(x2, x) + else: + x2 = torch.matmul(x.transpose(-2, -1), x) + return torch.matmul(x, x2) + +def compute_scaled_prod3(x): + # computes 3-way matrix power x^3 (x is treated as a batch of matrices) with a scaling such that (for each + # matrix in the batch) if all the singular values of the matrix are the same, the result will be identical to the input. + + rows, cols = x.shape[-2], x.shape[-1] + + eps = 1.0e-40 + x_meansq = (x ** 2).mean(dim=(-2, -1), keepdim=True) + eps + x = x * (x_meansq * max(rows, cols)) ** (-1/3) + return compute_prod3(x) + + +def get_matrix_shape(shape): + shape = list(shape) + batch_size = shape[0] # batch size is 1st element of shape + shape = shape[1:] + def prod(l): + ans = l[0] + for n in l[1:]: + ans = ans * n + return ans + n = len(shape) + diffs = [ ] + for i in range(1, n): + prod1 = prod(shape[:i]) + prod2 = prod(shape[i:]) + diff = abs(prod1 - prod2) + diffs.append(diff) + min_diff = min(diffs) + for i in range(1, n): + if diffs[i-1] == min_diff: + return batch_size, prod(shape[:i]), prod(shape[i:]) + + +def cubic_decay_step(group, state, grad): + delta = grad + + lr = group["lr"] + eps = group["eps"] + step = state["step"] + beta_ceil = 1. - 1. / (10. + 0.2 * step) + beta1 = min(group["beta1"], beta_ceil) + beta2 = min(group["beta2"], beta_ceil) + direct = group["direct"] + cubic_decay_proportion = group["cubic_decay_proportion"] + linear_decay_proportion = 1. - cubic_decay_proportion + + min_scale, max_scale = group["scale_limits"] + + try: + stored_delta = state["delta"] + except KeyError as e: + assert step < 2 + # scalar. use conventional momentum. + stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["delta"] = stored_delta + + def min_sum_scale(x, y): + # returns the scale alpha such that (x + alpha y) is minimized. x and y have + # the same shape and the shape of alpha is (x.shape[0], 1, 1, ...). + assert x.ndim > 1 + dims = list(range(1, x.ndim)) + yy = (y ** 2).sum(dim=dims, keepdim=True) + xy = (y * x).sum(dim=dims, keepdim=True) + # sum square of x + alpha y is xx + alpha^2 yy + 2 alpha xy + # d/dalpha[that] = 2 alpha yy + 2 xy + # alpha = xy / yy + return -xy / (yy + eps) + + d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) + assert d.untyped_storage() is stored_delta.untyped_storage() + (batch_size, rows, cols) = d.shape + + if "row_stats" not in state: + state["row_stats"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) + state["direct_row_stats"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) + state["col_stats"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) + state["direct_col_stats"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) + + row_stats = state["row_stats"] + col_stats = state["col_stats"] + direct_row_stats = state["direct_row_stats"] + direct_col_stats = state["direct_col_stats"] + + delta = delta.reshape(*d.shape) + + d.add_(delta) # the scale used here doesn't matter as it all gets normalized. + d.mul_(1 - (linear_decay_proportion * (1 - beta1))) + + d2 = d ** 2 + + # we'll scale both before and after the cubing. + # the lines where we divide by sqrt of the mean are so we don't double + # count the scalar component of this. + row_scale = (row_stats + eps).sqrt() + col_scale = (col_stats + eps).sqrt() + row_col_scale = row_scale * col_scale + + d_norm1 = d / row_col_scale # this is the first of two steps of normalizing by these stats. + + prod3 = compute_scaled_prod3(d_norm1) + + alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + # we multiply prod3 by row_col_scale to "un-normalize". + # In the normal case where we're not limited by stability-of-update-concerns, + # the next line of code is equivalent to: + # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_proportion) + d.add_((prod3 * row_col_scale) * alpha) + + d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. + + d_norm1_sq = d_norm1 ** 2 + + # first update row_stats. + row_stats.mul_(beta2).add_((d_norm1 ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) + + # d_norm1b means we've doing the second normalization but only by rows so far. + d_norm1b = d_norm1 / (row_stats + eps).sqrt() + + col_stats.mul_(beta2).add_((d_norm1b ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + + d_norm2 = d_norm1b / (col_stats + eps).sqrt() + + # do "immediate" normalization of total norm to make the overall scale of the update what + # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. + # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style + # accumulator with beta equal to beta1. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + + d_norm3 = d_norm2 * (assumed_scale / ((d_norm2 ** 2).mean(dim=(1, 2), keepdim=True) + eps).sqrt()) + + moving_update = d_norm3 + + if direct == 0.0: + return -lr * moving_update.reshape(*grad.shape) + + # row/col normalization of direct/bypass gradient "delta". + direct_row_stats.mul_(beta2).add_((delta ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) + delta = delta / (direct_row_stats + eps).sqrt() + direct_col_stats.mul_(beta2).add_((delta ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + delta = delta / (direct_col_stats + eps).sqrt() + + ans = (-lr * (1-direct)) * moving_update + (-lr * direct) * delta + return ans.reshape(*grad.shape) + + +def scaling_step(group, param, state, grad): + lr = group["lr"] + wd = group["wd"] + + if grad.ndim >= 3 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): + delta = cubic_decay_step(group, state, grad) + else: + # biases and similar-shaped tensors + delta = adam_step(group, state, grad) + + try: + scale = state["scale"] + scale_grad_buf = state["scale_grad_buffer"] + except: + shape = [ param.shape[0] ] + [1] * (param.ndim - 1) + scale = torch.ones(*shape, device=grad.device) + scale_grad_buf = torch.zeros(*shape, device=grad.device) + state["scale"] = scale + state["scale_grad_buffer"] = scale_grad_buf + + momentum = 0.95 + min_scale, max_scale = group["scale_limits"] + + dims = list(range(1, param.ndim)) + + scale_grad = (grad * param.detach()).sum(dim=dims, keepdim=True) + scale_grad_buf.mul_(momentum).add_(scale_grad) + + old_scale = scale.clone() + + scale.add_(scale_grad_buf.sign(), alpha=-lr) + scale.clamp_(min=min_scale, max=max_scale) + + scale_ratio = scale / old_scale + + delta_scale = (scale_ratio * (1 - (lr * wd) ** 2)) - 1 + return param * delta_scale + scale * delta + + +def adam_step(group, state, grad): + lr = group["lr"] + step = state["step"] + eps = group["eps"] + # just hardcode these. we only use this code for biases and scalars. + beta1 = 0.98 + beta2 = 0.98 + + try: + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + except KeyError as e: + assert step < 2 + exp_avg = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq + + exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + bias_correction2 = 1 - beta2 ** (step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = (exp_avg_sq + eps).sqrt() + + return -lr * (exp_avg / denom) + + + + +class BatchedRubik(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + Unlike common optimizers, which accept model.parameters() or groups of parameters(), + this optimizer could accept model.named_parameters() or groups of named_parameters(). + See comments of function _get_names_of_parameters for its 4 possible cases. + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + beta2: beta2 is the momentum constant for moving-grad-squared as in Adam. + Must satisfy 0 < beta <= beta2 < 1. + betas: a list of decay constants for momentum on the parameter-change + scales: a list of scales corresponding to each of the betas, that we multiply + each momentum-update by. Implicitly there is also a beta=0, scale=1, + i.e. a non-decayed update. + scaling_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each non-scalar parameter tensor. If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + scale_decay: A constant similar to the weight_decay of AdamW, that applies on the scaling + factors, decaying them in log-space to scale_default. + scale_default: A constant that dictates the RMS value to which weight magnitudes decay. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. + eps: A general-purpose epsilon to prevent division by zero + """ + def __init__( + self, + params, + lr=1e-03, + beta1=0.995, + direct=0.15, # scale on bypass of momentum (beta1) + cubic_decay_proportion=0.8, + beta2=0.98, + wd=12, + eps=1.0e-16, + scale_limits=(1.0, 4.0), + ): + + defaults = dict( + lr=lr, + beta1=beta1, + direct=direct, + cubic_decay_proportion=cubic_decay_proportion, + beta2=beta2, + eps=eps, + wd=wd, + scale_limits=scale_limits, + ) + + param_groups, parameters_names = self._get_names_of_parameters(params) + super(BatchedRubik, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way TransformedAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = TransformedAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = TransformedAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = TransformedAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = TransformedAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. + # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + if "named_params" in cur_group: + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + else: + assert "params" in cur_group + name_list = ["foo" for _ in cur_group["params"]] + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + + + def __setstate__(self, state): + super(TransformedAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + with self.batched_params(group["params"], group_params_names) as batches: + + for p, state, _names in batches: + grad = p.grad + + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + if p.numel() == p.shape[0]: + p += adam_step(group, state, grad) + else: + p += scaling_step(group, p.detach(), state, grad) + + state["step"] = cur_step + 1 + + + return loss + + + +def _test_batched_rubik(hidden_dim: int): + import timeit + + from scaling import OrthogonalLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_batched_rubik") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + torch.random.manual_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + if True: + Linear = torch.nn.Linear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + lr = 0.001 + # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the + # optimum parameters very exactly. Normally you want something more like the + # defaults of beta1=0.995 and direct=0.15 + optim = BatchedRubik(m.named_parameters(), lr=lr, direct=0.0, beta1=0.999) + + num_epochs = 180 + + total_steps = num_epochs + def lr_lambda(current_step): + # a LR schedule similar to InterpCosineLRScheduler from combined_scheduler.py + progress = min(1, current_step / total_steps) + cos = math.cos(progress * math.pi / 2) + # the relatively small scale on cos means the linear cool-down phase + # is long/slow, as the loss of this easy task is dominated by + # parameter noise.. in practical scenarios we use larger scale on + # the cos term, e.g. as large as 0.66. + return 0.05 * cos + 0.95 * (cos ** 2) + + scheduler = LambdaLR(optim, lr_lambda) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + # if epoch == 100 and test in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 512 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + #scheduler.step_batch() + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + norm3 = '%.2e' % (m[4].weight**2).mean().sqrt().item() + + bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() + bias_norm2 = '%.2e' % (m[2].bias**2).mean().sqrt().item() + bias_norm3 = '%.2e' % (m[4].bias**2).mean().sqrt().item() + + lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + ) + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step() # step once per epoch + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Time taken: {stop - start}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +def _test_muon(hidden_dim: int): + import timeit + + from muon import Muon + from scaling import OrthogonalLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_muon") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + if True: + fix_random_seed(42) + Linear = torch.nn.Linear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + OrthogonalLinear(hidden_dim, hidden_dim, bias=True, + in_groups=2, group_size=hidden_dim//4), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + optim = Muon(m.parameters(), + lr=0.5e-03, + wd=12.0) + + num_epochs = 180 + # hardcode batches per epoch for now. + total_steps = num_epochs + constant_fraction = 0.25 + + def lr_lambda(current_step): + progress = current_step / total_steps + if progress < constant_fraction: + return 1.0 + else: + return (1.0 - progress) / (1.0 - constant_fraction) + + scheduler = LambdaLR(optim, lr_lambda) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(num_epochs): + scheduler.step() + + # if epoch == 100 and test in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 512 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + norm2 = '%.2e' % (m[1].weight**2).mean().sqrt().item() + norm3 = '%.2e' % (m[3].weight**2).mean().sqrt().item() + norm4 = '%.2e' % (m[5].weight**2).mean().sqrt().item() + + bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() + bias_norm2 = '%.2e' % (m[3].bias**2).mean().sqrt().item() + bias_norm3 = '%.2e' % (m[5].bias**2).mean().sqrt().item() + + lr = scheduler.get_last_lr()[0] + logging.info( + f"Test muon, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3,norm4}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + ) + loss.log().backward() + optim.step() + optim.zero_grad() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Muon: time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +def _test_compute_scaled_prod3(): + x = torch.randn(3, 16, 32) + _U, _S, V = torch.linalg.svd(x, full_matrices=False) + W = V * torch.randn(3, 1, 1) + # so now all the singular values of x will be identical (but arbitrary) + + X = compute_scaled_prod3(W) + #print("X = ", X[0]) + #print("W = ", W[0]) + assert torch.allclose(W, X, atol=1.0e-03) + # but the result won't be identical to the input if the singular values are not all identical. + assert not torch.allclose(x, compute_scaled_prod3(x), atol=1.0e-03) + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_compute_scaled_prod3() + _test_batched_rubik(hidden_dim) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py new file mode 100644 index 0000000000..2424467193 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -0,0 +1,470 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import math +import logging +import random +from collections import defaultdict +from torch.optim.lr_scheduler import LambdaLR + +from typing import Dict, List, Optional, Tuple, Union +import torch +from torch import Tensor +from torch.optim import Optimizer + + +def compute_prod3(x): + assert x.ndim >= 2 + if x.shape[-2] <= x.shape[-1]: + x2 = torch.matmul(x, x.transpose(-2, -1)) + return torch.matmul(x2, x) + else: + x2 = torch.matmul(x.transpose(-2, -1), x) + return torch.matmul(x, x2) + +def compute_scaled_prod3(x): + # computes 3-way matrix power x^3 (x is treated as a batch of matrices) with a scaling such that (for each + # matrix in the batch) if all the singular values of the matrix are the same, the result will be identical to the input. + + rows, cols = x.shape[-2], x.shape[-1] + + eps = 1.0e-40 + x_meansq = (x ** 2).mean(dim=(-2, -1), keepdim=True) + eps + x = x * (x_meansq * max(rows, cols)) ** (-1/3) + return compute_prod3(x) + + +def get_matrix_shape(shape): + shape = list(shape) + def prod(l): + ans = l[0] + for n in l[1:]: + ans = ans * n + return ans + n = len(shape) + diffs = [ ] + for i in range(1, n): + prod1 = prod(shape[:i]) + prod2 = prod(shape[i:]) + diff = abs(prod1 - prod2) + diffs.append(diff) + min_diff = min(diffs) + for i in range(1, n): + if diffs[i-1] == min_diff: + return prod(shape[:i]), prod(shape[i:]) + assert False, shape + + +def cubic_decay_step(group, state, grad): + delta = grad + + lr = group["lr"] + eps = group["eps"] + step = state["step"] + beta_ceil = 1. - 1. / (10. + 0.2 * step) + beta1 = min(group["beta1"], beta_ceil) + beta2 = min(group["beta2"], beta_ceil) + direct = group["direct"] + cubic_decay_proportion = group["cubic_decay_proportion"] + linear_decay_proportion = 1. - cubic_decay_proportion + + min_scale, max_scale = group["scale_limits"] + + try: + stored_delta = state["delta"] + except KeyError as e: + assert step < 2 + # scalar. use conventional momentum. + stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["delta"] = stored_delta + + def min_sum_scale(x, y): + # returns the scale alpha such that (x + alpha y) is minimized; x and + # y each have 2 dimensions. + return -(x * y).sum() / ((y ** 2).sum() + eps) + + d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) + assert d.untyped_storage() is stored_delta.untyped_storage() + (rows, cols) = d.shape + + if "row_stats" not in state: + state["row_stats"] = torch.ones(rows, 1, device=d.device, dtype=d.dtype) + state["direct_row_stats"] = torch.ones(rows, 1, device=d.device, dtype=d.dtype) + state["col_stats"] = torch.ones(1, cols, device=d.device, dtype=d.dtype) + state["direct_col_stats"] = torch.ones(1, cols, device=d.device, dtype=d.dtype) + + row_stats = state["row_stats"] + col_stats = state["col_stats"] + direct_row_stats = state["direct_row_stats"] + direct_col_stats = state["direct_col_stats"] + + delta = delta.reshape(*d.shape) + + d.add_(delta) # the scale used here doesn't matter as it all gets normalized. + d.mul_(1 - (linear_decay_proportion * (1 - beta1))) + + d2 = d ** 2 + + # we'll scale both before and after the cubing. + # the lines where we divide by sqrt of the mean are so we don't double + # count the scalar component of this. + row_scale = (row_stats + eps).sqrt() + col_scale = (col_stats + eps).sqrt() + row_col_scale = row_scale * col_scale + + d_norm1 = d / row_col_scale # this is the first of two steps of normalizing by these stats. + + prod3 = compute_scaled_prod3(d_norm1) + + alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + # we multiply prod3 by row_col_scale to "un-normalize". + # In the normal case where we're not limited by stability-of-update-concerns, + # the next line of code is equivalent to: + # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_proportion) + d.add_((prod3 * row_col_scale) * alpha) + + d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. + + d_norm1_sq = d_norm1 ** 2 + + # first update row_stats. + row_stats.mul_(beta2).add_((d_norm1 ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + + # d_norm1b means we've doing the second normalization but only by rows so far. + d_norm1b = d_norm1 / (row_stats + eps).sqrt() + + col_stats.mul_(beta2).add_((d_norm1b ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) + + d_norm2 = d_norm1b / (col_stats + eps).sqrt() + + # do "immediate" normalization of total norm to make the overall scale of the update what + # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. + # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style + # accumulator with beta equal to beta1. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + + d_norm3 = d_norm2 * (assumed_scale / ((d_norm2 ** 2).mean() + eps) .sqrt()) + + moving_update = d_norm3 + + if direct == 0.0: + return -lr * moving_update.reshape(*grad.shape) + + # row/col normalization of direct/bypass gradient "delta". + direct_row_stats.mul_(beta2).add_((delta ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + delta = delta / (direct_row_stats + eps).sqrt() + direct_col_stats.mul_(beta2).add_((delta ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) + delta = delta / (direct_col_stats + eps).sqrt() + + ans = (-lr * (1-direct)) * moving_update + (-lr * direct) * delta + return ans.reshape(*grad.shape) + + +def scaling_step(group, param, state, grad): + lr = group["lr"] + wd = group["wd"] + + if grad.ndim >= 2 and grad.numel() != max(grad.shape): + delta = cubic_decay_step(group, state, grad) + else: + # biases and similar-shaped tensors + delta = adam_step(group, state, grad) + + try: + scale = state["scale"] + scale_grad_buf = state["scale_grad_buffer"] + except: + scale = torch.ones(1, device=grad.device) + scale_grad_buf = torch.zeros(1, device=grad.device) + state["scale"] = scale + state["scale_grad_buffer"] = scale_grad_buf + + momentum = 0.95 + min_scale, max_scale = group["scale_limits"] + + + scale_grad = (grad * param.detach()).sum() + scale_grad_buf.mul_(momentum).add_(scale_grad) + + old_scale = scale.clone() + + scale.add_(scale_grad_buf.sign(), alpha=-lr) + scale.clamp_(min=min_scale, max=max_scale) + + scale_ratio = scale / old_scale + + delta_scale = (scale_ratio * (1 - (lr * wd) ** 2)) - 1 + return param * delta_scale + scale * delta + + +def adam_step(group, state, grad): + lr = group["lr"] + step = state["step"] + eps = group["eps"] + # just hardcode these. we only use this code for biases and scalars. + beta1 = 0.98 + beta2 = 0.98 + + try: + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + except KeyError as e: + assert step < 2 + exp_avg = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq + + exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + bias_correction2 = 1 - beta2 ** (step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = (exp_avg_sq + eps).sqrt() + + return -lr * (exp_avg / denom) + + + +class Rubik(Optimizer): + """ + Version of TransformedAdam that doesn't do the batching or gradient clipping (may be easier to integrate + into other frameworks). + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses). + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + beta2: beta2 is the momentum constant for moving-grad-squared as in Adam. + Must satisfy 0 < beta <= beta2 < 1. + betas: a list of decay constants for momentum on the parameter-change + scales: a list of scales corresponding to each of the betas, that we multiply + each momentum-update by. Implicitly there is also a beta=0, scale=1, + i.e. a non-decayed update. + """ + def __init__( + self, + params, + lr=1e-03, + beta1=0.995, + direct=0.15, # scale on bypass of momentum (beta1) + cubic_decay_proportion=0.8, + beta2=0.98, + wd=12, + eps=1.0e-16, + scale_limits=(1.0, 4.0), + ): + defaults = dict( + lr=lr, + beta1=beta1, + direct=direct, + cubic_decay_proportion=cubic_decay_proportion, + beta2=beta2, + eps=eps, + wd=wd, + scale_limits=scale_limits, + ) + super().__init__(params, defaults) + + + def __setstate__(self, state): + super(TransformedAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group in self.param_groups: + + for p in group['params']: + state = self.state[p] + grad = p.grad + + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + def u(x): + return x.unsqueeze(0) + + if p.numel() == 1: + p += adam_step(group, state, grad) + else: + p += scaling_step(group, u(p.detach()), state, u(grad))[0] + + state["step"] = cur_step + 1 + + return loss + + + +def _test_rubik(hidden_dim: int): + import timeit + + E = 100 + B = 4 + T = 2 + logging.info("in test_rubik") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + torch.random.manual_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + if True: + Linear = torch.nn.Linear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + lr = 0.001 + # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the + # optimum parameters very exactly. Normally you want something more like the + # defaults of beta1=0.995 and direct=0.15 + optim = Rubik(m.parameters(), lr=lr, direct=0.0, beta1=0.999) + + num_epochs = 180 + + total_steps = num_epochs + def lr_lambda(current_step): + # a LR schedule similar to InterpCosineLRScheduler from combined_scheduler.py + progress = min(1, current_step / total_steps) + cos = math.cos(progress * math.pi / 2) + # the relatively small scale on cos means the linear cool-down phase + # is long/slow, as the loss of this easy task is dominated by + # parameter noise.. in practical scenarios we use larger scale on + # the cos term, e.g. as large as 0.66. + return 0.05 * cos + 0.95 * (cos ** 2) + + scheduler = LambdaLR(optim, lr_lambda) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + # if epoch == 100 and test in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 512 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + #scheduler.step_batch() + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + norm3 = '%.2e' % (m[4].weight**2).mean().sqrt().item() + + bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() + bias_norm2 = '%.2e' % (m[2].bias**2).mean().sqrt().item() + bias_norm3 = '%.2e' % (m[4].bias**2).mean().sqrt().item() + + lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + ) + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step() # step once per epoch + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Time taken: {stop - start}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +def _test_compute_scaled_prod3(): + x = torch.randn(16, 32) + _U, _S, V = torch.linalg.svd(x, full_matrices=False) + W = V * torch.randn(1, 1) + # so now all the singular values of x will be identical (but arbitrary) + + X = compute_scaled_prod3(W) + #print("X = ", X[0]) + #print("W = ", W[0]) + assert torch.allclose(W, X, atol=1.0e-03) + # but the result won't be identical to the input if the singular values are not all identical. + assert not torch.allclose(x, compute_scaled_prod3(x), atol=1.0e-03) + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_compute_scaled_prod3() + _test_rubik(hidden_dim) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index c0f89c435e..c4994f7c93 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -74,7 +74,16 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel -from optim import TransformedAdam +# the try-pass blocks around imports are to reduce the chance of failures when running multiple code +# versions in parallel; later, these can be removed. +try: + from batched_rubik import BatchedRubik as Rubik + # could also have done: + # from rubik import Rubik +except: + pass + + from combined_scheduler import CombinedLRScheduler try: from combined_scheduler import InterpCosineLRScheduler @@ -1365,7 +1374,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = TransformedAdam( + optimizer = Rubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, direct=0.15, From 096b9141f23d8a5e15999f4302eac2b423a5d9ae Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 19 Mar 2026 18:52:45 +0800 Subject: [PATCH 0981/1191] add commonvoice dataset --- .../ASR/zapformer/asr_datamodule.py | 41 ++++++++++++- egs/librispeech/ASR/zapformer/decode.py | 57 +++++++++++++++---- .../ASR/zapformer/streaming_decode.py | 34 +++++++++-- egs/librispeech/ASR/zapformer/train.py | 34 +++++++---- 4 files changed, 137 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 853c14c7c7..2b6d7fe132 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -98,7 +98,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "--full-libri", type=str2bool, default=True, - help="""When enabled, use 960h LibriSpeech; and 10000 hour GigaSpeech if --use-gigs. Otherwise, use 100h and if applicable 250h subsets.""", + help="""When enabled, use 960h LibriSpeech; and 10000 hour GigaSpeech if --use-giga. Otherwise, use 100h and if applicable 250h subsets.""", ) group.add_argument( "--mini-libri", @@ -224,7 +224,12 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="If set to True, use gigaspeech in addition to librispeech. See also --libri-copies." ) - + parser.add_argument( + "--use-cv", + type=str2bool, + default=False, + help="If set to True, use CommonVoice in addition to librispeech. See also --libri-copies." + ) def train_dataloaders( self, @@ -540,3 +545,35 @@ def dev_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" logging.info(f"About to get DEV cuts from {f}") return load_manifest_lazy(f) + + +class CommonVoice: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files:: + + - cv22-en_cuts_train.jsonl.gz + - cv22-en_cuts_dev.jsonl.gz + - cv22-en_cuts_test.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_cuts(self) -> CutSet: + logging.info("CommonVoice: About to get train cuts") + return load_manifest_lazy( + self.manifest_dir / "cv22-en_cuts_train.jsonl.gz" + ) + + def dev_cuts(self) -> CutSet: + logging.info("CommonVoice: About to get dev cuts") + return load_manifest_lazy( + self.manifest_dir / "cv22-en_cuts_dev.jsonl.gz" + ) + + def test_cuts(self) -> CutSet: + logging.info("CommonVoice: About to get test cuts") + return load_manifest_lazy( + self.manifest_dir / "cv22-en_cuts_test.jsonl.gz" + ) \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 841ff0142b..90c7a4a309 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -98,6 +98,7 @@ import logging import math import os +import re from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -106,7 +107,7 @@ import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeech, GigaSpeech, AsrDataModule +from asr_datamodule import CommonVoice, LibriSpeech, GigaSpeech, AsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -175,7 +176,7 @@ ) -def asr_text_post_processing(text: str) -> str: # only used for gigaspeech +def giga_asr_text_post_processing(text: str) -> str: # only used for gigaspeech # 1. convert to uppercase text = text.upper() @@ -192,13 +193,27 @@ def asr_text_post_processing(text: str) -> str: # only used for gigaspeech return " ".join(remaining_words) -def post_processing( + +def giga_post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = giga_asr_text_post_processing(" ".join(ref)).split() + new_hyp = giga_asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def cv_post_processing( results: List[Tuple[str, List[str], List[str]]], ) -> List[Tuple[str, List[str], List[str]]]: + def normalize(text): + return re.sub(r'[^\w\s]', '', text).upper() new_results = [] for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_ref = normalize(" ".join(ref)).split() + new_hyp = normalize(" ".join(hyp)).split() new_results.append((key, new_ref, new_hyp)) return new_results @@ -445,6 +460,13 @@ def get_parser(): help="""If True, decode gigaspeech in addition to librispeech test sets.""", ) + parser.add_argument( + "--cv", + type=str2bool, + default=False, + help="""If True, decode commonvoice in addition to librispeech test sets.""", + ) + add_model_arguments(parser) return parser @@ -799,8 +821,10 @@ def save_asr_output( recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - if params.giga: - results = post_processing(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) store_transcripts(filename=recogs_filename, texts=results) @@ -817,8 +841,10 @@ def save_wer_results( """ test_set_wers = dict() for key, results in results_dict.items(): - if params.giga: - results = post_processing(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. @@ -1138,8 +1164,17 @@ def main(): dev_cuts = gigaspeech.dev_cuts() giga_test_dl = asr_datamodule.test_dataloaders(test_cuts) giga_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) - test_sets += ["dev", "test"] - test_dl += [giga_test_dl, giga_dev_dl] + test_sets += ["giga-dev", "giga-test"] + test_dl += [giga_dev_dl, giga_test_dl] + + if args.cv: + commonvoice = CommonVoice(args.manifest_dir) + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + cv_test_dl = asr_datamodule.test_dataloaders(test_cuts) + cv_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) + test_sets += ["cv-dev", "cv-test"] + test_dl += [cv_dev_dl, cv_test_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py index 5c480e117e..a04ed04adf 100755 --- a/egs/librispeech/ASR/zapformer/streaming_decode.py +++ b/egs/librispeech/ASR/zapformer/streaming_decode.py @@ -40,7 +40,8 @@ import numpy as np import sentencepiece as spm import torch -from asr_datamodule import LibriSpeech, GigaSpeech, AsrDataModule +from asr_datamodule import CommonVoice, LibriSpeech, GigaSpeech, AsrDataModule +from decode import cv_post_processing, giga_post_processing from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet, set_caching_enabled @@ -209,6 +210,13 @@ def get_parser(): help="""If True, decode gigaspeech in addition to librispeech test sets.""", ) + parser.add_argument( + "--cv", + type=str2bool, + default=False, + help="""If True, decode commonvoice in addition to librispeech test sets.""", + ) + add_model_arguments(parser) return parser @@ -647,9 +655,14 @@ def save_asr_output( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) store_transcripts(filename=recogs_filename, texts=results) logging.info(f"The transcripts are stored in {recogs_filename}") + def save_wer_results( params: AttributeDict, test_set_name: str, @@ -660,6 +673,10 @@ def save_wer_results( """ test_set_wers = dict() for key, results in results_dict.items(): + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. @@ -857,10 +874,17 @@ def main(): if args.giga: gigaspeech = GigaSpeech(args.manifest_dir) - test_cuts = gigaspeech.test_cuts() - dev_cuts = gigaspeech.dev_cuts() - test_sets += ["dev", "test"] - test_cuts += [dev_cuts, test_cuts] + giga_test_cuts = gigaspeech.test_cuts() + giga_dev_cuts = gigaspeech.dev_cuts() + test_sets += ["giga-dev", "giga-test"] + test_cuts += [giga_dev_cuts, giga_test_cuts] + + if args.cv: + commonvoice = CommonVoice(args.manifest_dir) + cv_test_cuts = commonvoice.test_cuts() + cv_dev_cuts = commonvoice.dev_cuts() + test_sets += ["cv-dev", "cv-test"] + test_cuts += [cv_dev_cuts, cv_test_cuts] for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset( diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index c4994f7c93..5db8a1cb93 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -66,7 +66,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import AsrDataModule, LibriSpeech, GigaSpeech +from asr_datamodule import AsrDataModule, CommonVoice, LibriSpeech, GigaSpeech from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner @@ -1419,6 +1419,7 @@ def lr_lambda(current_step): asr_datamodule = AsrDataModule(args) librispeech = LibriSpeech(args.manifest_dir) gigaspeech = GigaSpeech(args.manifest_dir) # gigaspeech will only be used if the --use-giga=True option is set + commonvoice = CommonVoice(args.manifest_dir) # commonvoice will only be used if the --use-cv=True option is set if params.full_libri: train_cuts = librispeech.train_all_shuf_cuts() @@ -1436,19 +1437,30 @@ def lr_lambda(current_step): train_cuts = librispeech.train_clean_100_cuts() train_cuts_len = 100.0 * 3 # 100 hours times 3 for speed augmentation - if params.use_giga: - if params.full_libri: - gigaspeech_cuts = gigaspeech.train_XL_cuts() - gigaspeech_cuts_len = 10000.0 - else: - gigaspeech_cuts = gigaspeech.train_S_cuts() # e.g. for debugging - gigaspeech_cuts_len = 250.0 - + if params.use_giga or params.use_cv: if params.libri_copies > 1: train_cuts = train_cuts.repeat(params.libri_copies) train_cuts_len = train_cuts_len * params.libri_copies - datasets_and_weights = [ (train_cuts, train_cuts_len), - (gigaspeech_cuts, gigaspeech_cuts_len) ] + datasets_and_weights = [(train_cuts, train_cuts_len)] + + if params.use_giga: + if params.full_libri: + gigaspeech_cuts = gigaspeech.train_XL_cuts() + gigaspeech_cuts_len = 10000.0 + else: + gigaspeech_cuts = gigaspeech.train_S_cuts() # e.g. for debugging + gigaspeech_cuts_len = 250.0 + datasets_and_weights.append((gigaspeech_cuts, gigaspeech_cuts_len)) + + if params.use_cv: + import re + def normalize_text(c): + c.supervisions[0].text = re.sub(r'[^\w\s]', '', c.supervisions[0].text).upper() + return c + commonvoice_cuts = commonvoice.train_cuts().map(normalize_text) + commonvoice_cuts_len = 2600.0 + datasets_and_weights.append((commonvoice_cuts, commonvoice_cuts_len)) + cuts, weights = zip(*datasets_and_weights) train_cuts = CutSet.mux(*cuts, weights=weights) From 27424b4f582894c88b2c96b61165d158ad42c380 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Mar 2026 10:45:05 +0800 Subject: [PATCH 0982/1191] Replace FftConv with BasisConv --- egs/librispeech/ASR/zipformer/zipformer.py | 194 ++++++++++++++++++++- 1 file changed, 192 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 44f554683a..277dea01c8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1594,6 +1594,9 @@ def round_up_to_power_of_two(x): return x + + + class FftConv(nn.Module): def __init__(self, num_channels: int, @@ -1647,6 +1650,165 @@ def forward(self, return x +# convolution where we convolve with a combination of basis functions, the basis functions +# being based on linear interpolation in Fourier space-- in effect, each pair of basis functions +# corresponds to the real and imaginary coefficients for one triangular bin in Fourier space; +# in the time domain the triangular bin becomes a sinc^2 function and the frequency offset +# is just a complex exponential of which the real and imaginary coefficients give us sines and +# cosines. +def get_basis_funcs(seq_len: int, + num_freqs: int, + **kwargs +): + """ + seq_len: the sequence length to which the basis functions are truncated; this is expected to + be even + num_freqs: the number of frequencies; the number of basis functions will be 2 * num_freqs, + and note that the first pair of basis functions are special, because they are the + (zero-freq; nyquist-freq) ones. + kwargs: can be used for device + + Returns: + basis functions of shape: (2 * num_freqs, seq_len) + """ + assert seq_len % 2 == 0 + t = torch.cat((torch.arange(seq_len // 2, **kwargs), + torch.arange(-seq_len // 2, 0, **kwargs)), dim=0) # e.g. tensor([ 0, 1, 2, 3, -4, -3, -2, -1]) + # the second half of the "t" values are interpreted as the "negative half" of the time range-- + # the time range representing t values from -seq_len // 2 to seq_len // 2 - 1. + # The way we use this will be to convolve it with a signal of size seq_len // 2 that + # has been padded with zeroes of length seq_len // 2, and we want the result to be as if we padded with the basis + # functions from -infinity to infinity. + + + scaled_t = t * math.pi / num_freqs + + # "freqs" are the t values multiplied by the basis frequencies + t_freqs = scaled_t * torch.arange(num_freqs + 1, **kwargs).unsqueeze(-1) + # t_freqs: (num_freqs + 1, seq_len) + + # it's a sinc-squared envelope, as the frequency domain envelope is a + # triangular, not a rectangular, function. the factor of 0.5 comes + # from the math + sinc_arg = 0.5 * scaled_t + envelope = torch.where(sinc_arg != 0.0, sinc_arg.sin() / sinc_arg, torch.ones_like(sinc_arg)) ** 2 + + + cos, sin = t_freqs.cos() * envelope, t_freqs.sin() * envelope + #plt.plot(envelope) + + # the factor of 0.5 is because the other freqs would get "counted twice" due + # to having two symmetric versions, the freqs at zero and the nyquist only have + # one copy. This ensures that if we give a coeff of all ones on all the + # cos terms, we get (a scaled version of) the delta function. + sin[0] = 0.5 * cos[-1] + cos[0] = 0.5 * cos[0] + # the sin coefficient of freq 0 and nyquist gives us nothing, so we use the cos + # at the nyquist in this position. + cos = cos[:num_freqs] + sin = sin[:num_freqs] + #scale = num_freqs ** -0.5 # scale to make the funcs have a value around 1. + #cos = cos * scale + #sin = sin * scale + + basis = torch.cat((cos, sin), dim=0) + # basis: (2 * num_freqs, seq_len) + + #for i in range(num_freqs + 1): + # plt.plot(cos[i]) + # plt.plot(sin[i]) + # plt.show() + return basis + + +def fourier_conv(x: Tensor, y: Tensor): + # fourier based convolution of x and y, returns + # something with the same sequence length as the shorter of + # the two. + # x, y: (seq_len, [1 or batch_size], num_channels) + T = max(x.shape[0], y.shape[0]) + T_out = min(x.shape[0], y.shape[0]) + + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + y = y.to(torch.float) + X = torch.fft.rfft(x, dim=0, n=T) + Y = torch.fft.rfft(y, dim=0, n=T) + return torch.fft.irfft(X * Y, dim=0, n=T)[:T_out] + + +class FourierConv(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return fourier_conv(x, y) + + @staticmethod + def backward(ctx, ans_grad): + # we could probably do a bit better than this by doing it manually + x, y = ctx.saved_tensors + with torch.enable_grad(): + x = x.detach() + y = y.detach() + x.requires_grad = True + y.requires_grad = True + ans = fourier_conv(x, y) + ans.backward(gradient=ans_grad) + return x.grad, y.grad + + + + +class BasisConv(nn.Module): + def __init__(self, + num_channels: int, + num_freqs: int, + params_per_channel: int): + super().__init__() + self.weight_proj = nn.Linear(params_per_channel, 2 * num_freqs) + + self.weight = nn.Parameter(0.05 * torch.randn(num_channels, + params_per_channel)) + + + def forward(self, + x: Tensor) -> Tensor: + (seq_len, batch_size, num_channels) = x.shape + + + # round seq_len to a multiple of "round" to help ensure the FFT dimension + # has plenty of powers of two; this will tend to make it more efficient. + round = min(16, round_up_to_power_of_two(seq_len)) + seq_len_rounded = round * ((seq_len + round - 1) // round) + + # to ensure the answer is the same regardless of the amount of padding, we + # pad the sequence to at least twice its initial length for purposes of + # the FFT-based convolution. Because we will view the basis functions + # as going from t=-seq_len_rounded to t=seq_len_rounded - 1, this will + # ensure that we never see "wrap-around" effects. + T = 2 * seq_len_rounded + + num_freqs = self.weight_proj.weight.shape[0] // 2 + basis_funcs = get_basis_funcs(T, num_freqs, device=x.device) + # basis_funcs: (2 * num_freqs, T) + + scale = num_freqs ** -0.5 + + weight = scale * self.weight_proj(self.weight) + # weight: (num_channels, 2 * num_freqs) + channel_funcs = torch.matmul(weight, basis_funcs) + # channel_funcs: (num_channels, T) + + + # channel_funcs: (num_channels, T) + channel_funcs = channel_funcs.t().unsqueeze(1) + # channel_funcs: (T, 1, num_channels) + + return FourierConv.apply(channel_funcs, x) + + + + class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer2 model. Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py @@ -1684,9 +1846,10 @@ def __init__( self.activation2 = Identity() # for diagnostics - self.depthwise_conv = FftConv(bottleneck_dim, kernel_size) if not causal: - self.depthwise_conv = FftConv(bottleneck_dim, kernel_size) + self.depthwise_conv = BasisConv(bottleneck_dim, + num_freqs=kernel_size*2, + params_per_channel=kernel_size) else: self.depthwise_conv = nn.Conv1d( in_channels=bottleneck_dim, @@ -1957,10 +2120,37 @@ def _test_zipformer_streaming(): logging.info("Passed") + +def _test_basis_conv(): + num_channels = 11 + f = BasisConv(num_channels=num_channels, + num_freqs=4, + params_per_channel=2) + + seq_len = 100 + subseq_len = 10 # will help visualize the effect + batch_size = 2 + x = torch.cat((torch.randn(subseq_len, batch_size, num_channels), + torch.zeros(seq_len - subseq_len, batch_size, num_channels)), + dim=0) + + y = f(x) + + #plt.plot(x[:, 0, 0].detach()) + #plt.plot(y[:, 0, 0].detach()) + #plt.show() + + + def rms(a): + return (a**2).mean().item() + print(f"rms(x)={rms(x)}, rms(y)={rms(y)}") + + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) + _test_basis_conv() _test_zipformer_main(False) _test_zipformer_main(True) _test_zipformer_streaming() From c3921f926dedd0b3cdd27d01e21e32ad1d750b2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Mar 2026 20:49:36 +0800 Subject: [PATCH 0983/1191] Fix wrong class names in super() --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- egs/librispeech/ASR/zapformer/rubik.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 206673dfc7..42471da4b2 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -520,7 +520,7 @@ def _get_names_of_parameters( def __setstate__(self, state): - super(TransformedAdam, self).__setstate__(state) + super(BatchedRubik, self).__setstate__(state) @torch.no_grad() def step(self, closure=None): diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 2424467193..cbb62a2ca7 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -285,7 +285,7 @@ def __init__( def __setstate__(self, state): - super(TransformedAdam, self).__setstate__(state) + super(Rubik, self).__setstate__(state) @torch.no_grad() def step(self, closure=None): From 957b23b812c8f39430bc1f06ea9cfb9dc8ce1f2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Mar 2026 21:26:53 +0800 Subject: [PATCH 0984/1191] Implement WeightedMean to bypass convolutions; this breakds streaming test. --- egs/librispeech/ASR/zipformer/zipformer.py | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 277dea01c8..af687350af 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1757,7 +1757,41 @@ def backward(ctx, ans_grad): return x.grad, y.grad +class WeightedMean(nn.Module): + def __init__(self, + num_channels: int, + causal: bool = False): + super().__init__() + self.causal = causal + self.weights = nn.Parameter(0.1 * torch.randn(num_channels)) + + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + """ + Compute weighted mean. + x: (time, batch, channel) + src_key_padding_mask: (batch, time), True for masked positions + + Returned shape: (time, batch, channel) if causal else (batch, channel) + """ + T = x.shape[0] + if self.causal: + num_frames = torch.arange(1, T + 1, device=x.device) + x_cumsum = torch.cumsum(x, dim=0) + return x_cumsum * num_frames[:, None, None] * self.weights + + + # assume x already masked, if mask is in use. + if src_key_padding_mask is not None: + num_frames = src_key_padding_mask.logical_not().to(torch.float).sum(dim=1) + num_frames = num_frames.unsqueeze(-1).to(torch.float) + # num_frames: (batch_size, 1) + return x.mean(dim=0) * (T / num_frames) * self.weights + else: + return x.mean(dim=0) * self.weights + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) class BasisConv(nn.Module): def __init__(self, @@ -1846,6 +1880,7 @@ def __init__( self.activation2 = Identity() # for diagnostics + if not causal: self.depthwise_conv = BasisConv(bottleneck_dim, num_freqs=kernel_size*2, @@ -1862,6 +1897,10 @@ def __init__( self.left_pad = kernel_size - 1 self.depthwise_conv.lr_scale = 0.66 + # add average-of-all-frames to the "convolution."; it has extra power vs the convolution + # because the num frames differs between utterances. + self.weighted_mean = WeightedMean(bottleneck_dim, + causal=causal) self.out_proj = ActivationAndLinear( bottleneck_dim, @@ -1905,6 +1944,8 @@ def forward( if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) + + wm = self.weighted_mean(x) if self.causal: # Not support exporting a model for simulated streaming decoding assert not torch.jit.is_scripting() and not torch.jit.is_tracing() @@ -1916,6 +1957,8 @@ def forward( x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) else: x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) + x = x + wm # Add in the weighted-mean to the convolution; this adds extra power + # because the utterances differ in length. x = x * y x = self.out_proj(x) # (time, batch, channels) From a429e4899b679438ebc0fa12fb4e2261626a2ae0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Mar 2026 23:19:55 +0800 Subject: [PATCH 0985/1191] Bug fix re src_key_padding_mask, use it. --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index af687350af..b7c83cc3ea 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1945,7 +1945,7 @@ def forward( x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - wm = self.weighted_mean(x) + wm = self.weighted_mean(x, src_key_padding_mask) if self.causal: # Not support exporting a model for simulated streaming decoding assert not torch.jit.is_scripting() and not torch.jit.is_tracing() From fd8147c746d1f6a21fef424a55a9709420c67996 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Mar 2026 11:58:41 +0800 Subject: [PATCH 0986/1191] Use 4, not 2, copies of the data. --- egs/librispeech/ASR/zapformer/multicopy_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/multicopy_dataset.py b/egs/librispeech/ASR/zapformer/multicopy_dataset.py index ffac8b04af..f445adbe1f 100755 --- a/egs/librispeech/ASR/zapformer/multicopy_dataset.py +++ b/egs/librispeech/ASR/zapformer/multicopy_dataset.py @@ -120,13 +120,13 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]] if self.cut_transforms: orig_cuts = cuts - cuts = cuts.repeat(times=2) + cuts = cuts.repeat(times=4) for tnfm in self.cut_transforms: cuts = tnfm(cuts) #cuts = orig_cuts + cuts - num_copies = 2 + num_copies = 4 else: num_copies = 1 From bc7d0b6af07a6e0523276bbcaa04ca3191978736 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Mar 2026 12:20:13 +0800 Subject: [PATCH 0987/1191] Fix assertion. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 5db8a1cb93..3ac102c240 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -973,7 +973,7 @@ def compute_loss( # the assertion if it's a problem in future. (previously we used losses that # required the different copies to be in sync on the time dimension, e.g. # to use the same time warping; we don't do this any more.) - assert num_copies == 2 + #assert num_copies == 2 batch_size = features.shape[0] features = augmentation(features, feature_lens) else: From c3e0e8c6b441964e57b6811c69f72f1a6f787c3a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Mar 2026 16:46:42 +0800 Subject: [PATCH 0988/1191] Changes to multicopy_dataset.py and asr_datamodule.py; max-duration is now a total across all copies, and --num-copies is command-line arg; remove soft links to ../zipformer/; much code cleanup. --- .../ASR/zapformer/alternating_spec_augment.py | 213 +- .../ASR/zapformer/asr_datamodule.py | 25 +- .../ASR/zapformer/attention_decoder.py | 1 - .../ASR/zapformer/batched_rubik.py | 11 +- egs/librispeech/ASR/zapformer/beam_search.py | 3184 ++++++++++++++++- egs/librispeech/ASR/zapformer/ctc_decode.py | 36 +- egs/librispeech/ASR/zapformer/decode.py | 32 +- .../ASR/zapformer/decode_gigaspeech.py | 1 - .../ASR/zapformer/decode_stream.py | 1 - egs/librispeech/ASR/zapformer/decoder.py | 1 - .../ASR/zapformer/encoder_interface.py | 44 +- .../ASR/zapformer/export-onnx-ctc.py | 1 - .../zapformer/export-onnx-streaming-ctc.py | 1 - .../ASR/zapformer/export-onnx-streaming.py | 1 - egs/librispeech/ASR/zapformer/export-onnx.py | 1 - egs/librispeech/ASR/zapformer/export.py | 1 - egs/librispeech/ASR/zapformer/finetune.py | 1 - .../ASR/zapformer/generate_averaged_model.py | 1 - .../ASR/zapformer/jit_pretrained.py | 1 - .../ASR/zapformer/jit_pretrained_ctc.py | 438 ++- .../ASR/zapformer/jit_pretrained_streaming.py | 272 +- egs/librispeech/ASR/zapformer/joiner.py | 70 +- .../ASR/zapformer/label_smoothing.py | 110 +- egs/librispeech/ASR/zapformer/model.py | 2 +- .../ASR/zapformer/multicopy_dataset.py | 35 +- egs/librispeech/ASR/zapformer/muon.py | 285 +- egs/librispeech/ASR/zapformer/my_profile.py | 141 +- egs/librispeech/ASR/zapformer/onnx_check.py | 239 +- egs/librispeech/ASR/zapformer/onnx_decode.py | 325 +- .../onnx_pretrained-streaming-ctc.py | 428 ++- .../zapformer/onnx_pretrained-streaming.py | 548 ++- .../ASR/zapformer/onnx_pretrained.py | 423 ++- .../ASR/zapformer/onnx_pretrained_ctc.py | 215 +- .../ASR/zapformer/onnx_pretrained_ctc_H.py | 278 +- .../ASR/zapformer/onnx_pretrained_ctc_HL.py | 276 +- .../ASR/zapformer/onnx_pretrained_ctc_HLG.py | 276 +- .../onnx_pretrained_ctc_HLG_streaming.py | 440 ++- egs/librispeech/ASR/zapformer/optim.py | 1 - egs/librispeech/ASR/zapformer/pretrained.py | 381 +- .../ASR/zapformer/pretrained_ctc.py | 481 ++- egs/librispeech/ASR/zapformer/scaling.py | 1296 ++++++- .../ASR/zapformer/scaling_converter.py | 100 +- .../ASR/zapformer/streaming_beam_search.py | 296 +- .../ASR/zapformer/streaming_decode.py | 16 +- egs/librispeech/ASR/zapformer/subsampling.py | 402 ++- .../ASR/zapformer/test_subsampling.py | 151 +- egs/librispeech/ASR/zapformer/train.py | 28 +- egs/librispeech/ASR/zapformer/zapformer.py | 2078 +++++++++++ .../ASR/zapformer/zapformer_modules.py | 999 ++++++ .../ASR/zapformer/zapformer_utils.py | 181 + egs/librispeech/ASR/zapformer/zipformer.py | 1 - 51 files changed, 14432 insertions(+), 337 deletions(-) rename icefall/exp_augment.py => egs/librispeech/ASR/zapformer/alternating_spec_augment.py (66%) delete mode 120000 egs/librispeech/ASR/zapformer/attention_decoder.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/beam_search.py delete mode 120000 egs/librispeech/ASR/zapformer/decode_gigaspeech.py delete mode 120000 egs/librispeech/ASR/zapformer/decode_stream.py delete mode 120000 egs/librispeech/ASR/zapformer/decoder.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/encoder_interface.py delete mode 120000 egs/librispeech/ASR/zapformer/export-onnx-ctc.py delete mode 120000 egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py delete mode 120000 egs/librispeech/ASR/zapformer/export-onnx-streaming.py delete mode 120000 egs/librispeech/ASR/zapformer/export-onnx.py delete mode 120000 egs/librispeech/ASR/zapformer/export.py delete mode 120000 egs/librispeech/ASR/zapformer/finetune.py delete mode 120000 egs/librispeech/ASR/zapformer/generate_averaged_model.py delete mode 120000 egs/librispeech/ASR/zapformer/jit_pretrained.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/joiner.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/label_smoothing.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/muon.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/my_profile.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_check.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_decode.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py delete mode 120000 egs/librispeech/ASR/zapformer/optim.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/pretrained.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/pretrained_ctc.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/scaling.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/scaling_converter.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/streaming_beam_search.py mode change 120000 => 100644 egs/librispeech/ASR/zapformer/subsampling.py mode change 120000 => 100755 egs/librispeech/ASR/zapformer/test_subsampling.py create mode 100644 egs/librispeech/ASR/zapformer/zapformer.py create mode 100644 egs/librispeech/ASR/zapformer/zapformer_modules.py create mode 100644 egs/librispeech/ASR/zapformer/zapformer_utils.py delete mode 120000 egs/librispeech/ASR/zapformer/zipformer.py diff --git a/icefall/exp_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py similarity index 66% rename from icefall/exp_augment.py rename to egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 1bfb97e576..6bf6038254 100644 --- a/icefall/exp_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -5,16 +5,22 @@ -class ExpAugment(torch.nn.Module): +class AlternatingSpecAugment(torch.nn.Module): """ - ExpAugment is a different, simpler implementation of the feature-masking and frame-masking - aspects of SpecAugment, without the time warping for now. + AlternatingSpecAugment is a different version of feature-masking and frame-masking + aspects of SpecAugment, without the time warping for now (we use time_warp + from lhotse which is the same as the original SpecAugment). + + The main difference is in how it selects the regions to be masked, they are selected + for pairs of sequences in such a way that there tends to be a good amount of spacing between + masked regions; the masked regions never overlap and will never be extremely close to + each other. We also use a relatively large masked-fraction """ def __init__( self, - max_feature_mask_fraction: float = 0.675, # max fraction that can possibly be masked + max_feature_mask_fraction: float = 0.675, # max fraction that can possibly be masked; the expected masked-fraction is half of this. num_feature_masks: int = 2, - max_frame_mask_fraction: float = 0.725, + max_frame_mask_fraction: float = 0.725, # the expected temporal masked-fraction is half of this. max_frame_mask_size: float = 70, # max size in frames of temporal masks. p=0.9, # probability of doing augmentation ): @@ -235,200 +241,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]): setattr(self, name, state_dict["name"]) - -def hz_to_mel(hz: torch.Tensor): - return 1127.0 * torch.log(1 + hz / 700) - - -def mel_to_hz(mel: torch.Tensor): - return 700 * ((mel / 1127).exp() - 1) - - -def compute_mel_normalized_indexes( - low_freq_hz: float, - high_freq_hz: float, - sample_rate_hz: float, - num_mel_bins: float, - shift: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Return a tuple containing normalized indexes. - - - The first tensor is for expansion, i.e., map the second-to-last - bin to the last bin - - - The second tensor is for contraction, i.e., map the last bin to - the second-to-last bin - """ - nyquist = sample_rate_hz * 0.5 - if high_freq_hz <= 0: - high_freq_hz = nyquist + high_freq_hz - - assert 0 <= low_freq_hz < high_freq_hz <= nyquist, ( - low_freq_hz, - high_freq_hz, - nyquist, - sample_rate_hz, - ) - assert num_mel_bins > 1, num_mel_bins - - low_high_mel = hz_to_mel( - torch.tensor([low_freq_hz, high_freq_hz], dtype=torch.float32) - ) - - # divided by num_mel_bins + 1 to match the one used in Kaldi - mel_freq_delta = (low_high_mel[1] - low_high_mel[0]) / (num_mel_bins + 1) - - # the formulate to compute the mel tensor below is from Kaldi - mel = low_high_mel[0] + mel_freq_delta * torch.arange(num_mel_bins) - - hz = mel_to_hz(mel) - - expansion_scale = hz[-1] / hz[-1 - shift] # e.g. 1.0338 - contraction_scale = 1 / expansion_scale # e.g., 0.9673 - - mel_expanded = hz_to_mel(hz * expansion_scale) - mel_contracted = hz_to_mel(hz * contraction_scale) - - mel_expanded_indexes = (mel_expanded - low_high_mel[0]) / mel_freq_delta - mel_contracted_indexes = (mel_contracted - low_high_mel[0]) / mel_freq_delta - - mel_expanded_normalized_indexes = mel_expanded_indexes * 2 / (num_mel_bins - 1) - 1 - - mel_contracted_normalized_indexes = ( - mel_contracted_indexes * 2 / (num_mel_bins - 1) - 1 - ) - - return mel_expanded_normalized_indexes, mel_contracted_normalized_indexes - - -class MelWarp(torch.nn.Module): - def __init__( - self, - low_freq_hz: float, - high_freq_hz: float, - sample_rate_hz: float, - num_mel_bins: int, - p: float, - max_shift: int = 1, - ): - super().__init__() - - assert 0 <= p <= 1, p - assert 1 <= max_shift < num_mel_bins - 1 - - indexes = [] - for i in range(1, max_shift + 1): - expansion_indexes, contraction_indexes = compute_mel_normalized_indexes( - low_freq_hz=low_freq_hz, - high_freq_hz=high_freq_hz, - sample_rate_hz=sample_rate_hz, - num_mel_bins=num_mel_bins, - shift=i, - ) - indexes.append(expansion_indexes) - indexes.append(contraction_indexes) - - self.register_buffer('indexes', torch.stack(indexes, dim=0)) - - self.num_mel_bins = num_mel_bins - self.p = p - - def forward(self, features: torch.Tensor) -> torch.Tensor: - B, T, C = features.shape - assert C == self.num_mel_bins, (C, self.num_mel_bins) - - device = features.device - - features = features.permute(0, 2, 1) - - # grid sample requires (N,C,H,W) input - # we treat the feature axis as h, the time axis as w - # and use 1 for the channel in NCHW - - h = torch.linspace(-1, 1, C, device=device)[None, :, None].expand(B, C, T).to(device) - - # select a different index for each audio in the batch - # where each index corresponds to a shift - index = torch.randint( - low=0, high=self.indexes.shape[0], size=(B,), dtype=torch.int64, device=device, - ) - - warped_indexes = self.indexes[index][:, :, None].expand(B, C, T).to(device) - - h_positions = torch.where( - torch.rand(B, 1, 1, device=device).expand_as(features) < self.p, - warped_indexes, - h, - ) - - w = torch.linspace(-1, 1, T, device=device)[None, None, :].expand(B, C, T) - - grid = torch.stack([w, h_positions], axis=-1) - - features = torch.nn.functional.grid_sample( - features.unsqueeze(1), - grid, - mode="bicubic", - padding_mode="border", - align_corners=True, - ) - return features.squeeze(1).permute(0, 2, 1) - - -def _test_grid_sample(): - f = torch.rand(50, 20, 80) # (batch, time, features) - B, T, C = f.shape - - h = torch.linspace(-1, 1, C)[None, :, None].expand(B, C, T) - w = torch.linspace(-1, 1, T)[None, None, :].expand(B, C, T) - # w is x - # h is y - grid = torch.stack([w, h], axis=-1) - f2 = [] - for aligned in [True, False]: - f2.append( - torch.nn.functional.grid_sample( - f.permute(0, 2, 1).unsqueeze(1), - grid, - mode="bicubic", - padding_mode="border", - align_corners=aligned, - ) - .squeeze(1) - .permute(0, 2, 1) - ) - print("align_corners=true", (f - f2[0]).abs().max()) # aligned true - print("align_corners=false", (f - f2[1]).abs().max()) # aligned false - - -def _test_mel_warp(): - # The parameters used in testing are default values in lhotse - mel_warp = MelWarp( - low_freq_hz=20, - high_freq_hz=-400, - sample_rate_hz=16000, - num_mel_bins=80, - p=1, - max_shift=4, - ) - - f0 = torch.rand(2, 20, 80) * 10 - f1 = mel_warp(f0) - - assert f0.shape == f1.shape - print((f0 - f1).abs().max()) - - - -def _test_exp_augment(): +def _test_alternating_spec_augment(): for n in [ 0, 1 ]: #device = 'cuda' B, T, F = 301, 600, 80 device = 'cpu' if n == 0: - exp_augment = ExpAugment() #, max_frame_mask_size=2.0, max_frame_mask_fraction=0.02) + aspec_augment = AlternatingSpecAugment() else: from lhotse.dataset import SpecAugment time_mask_ratio = 3.5 @@ -452,12 +272,12 @@ def _test_exp_augment(): torch.zeros(B, device=device, dtype=torch.long), # start_frame T * torch.ones(B, device=device, dtype=torch.long) # num_frames ), dim=-1) - exp_augment = lambda x: spec_augment(x, supervision_segments) + aspec_augment = lambda x: spec_augment(x, supervision_segments) features = torch.randn(B, T, F, device=device) lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) #print("features=", features) - features = exp_augment(features) + features = aspec_augment(features) frame_is_masked = features[:, :, 0] == features[:, :, -1] print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) @@ -469,5 +289,4 @@ def _test_exp_augment(): # from lhotse.dataset import SpecAugment if __name__ == '__main__': - _test_exp_augment() - _test_mel_warp() + _test_alternating_spec_augment() diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 2b6d7fe132..12a894e818 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -33,7 +33,7 @@ SimpleCutSampler, ) # MulticopyDataset is a modified version of K2SpeechRecognitionDataset from -# lhotse.dataset, modified to, in training mode, to return a batch that has 2 +# lhotse.dataset, modified to, in training mode, to return a batch that has multiple # different copies of the same data having different Musan # augmentations and the first having none; and also include the key "num_copies" # in the batch which would be 1 for the validation data (no Musan) and 2 for the @@ -116,9 +116,10 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--max-duration", type=int, - default=200.0, + default=800.0, help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + "single batch, including the --num-copies argument, so if --num-copies " + "is larger the actual duration prior to making copies will be smaller." ) group.add_argument( "--bucketing-sampler", @@ -209,6 +210,15 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="AudioSamples or PrecomputedFeatures", ) + group.add_argument( + "--num-copies", + type=str, + default=4, + help="The number of copies of each training example selected in each batch (they will be augmented " + "differently). If you make num-copies larger there will be more steps per epoch so you should probably make " + "num-epochs smaller. " + ) + parser.add_argument( "--libri-copies", type=int, @@ -270,6 +280,7 @@ def train_dataloaders( logging.info("About to create train dataset") train = MulticopyDataset( + num_copies=self.args.num_copies, input_strategy=eval(self.args.input_strategy)(), cut_transforms=transforms, input_transforms=[], @@ -298,7 +309,7 @@ def train_dataloaders( logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( cuts_train, - max_duration=self.args.max_duration, + max_duration=self.args.max_duration / self.args.num_copies, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, buffer_size=self.args.num_buckets * 2000, @@ -309,7 +320,7 @@ def train_dataloaders( logging.info("Using SimpleCutSampler.") train_sampler = SimpleCutSampler( cuts_train, - max_duration=self.args.max_duration, + max_duration=self.args.max_duration / self.args.num_copies, shuffle=self.args.shuffle, ) logging.info("About to create train dataloader") @@ -346,12 +357,14 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: validate = MulticopyDataset( + num_copies=1, cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: validate = MulticopyDataset( + num_copies=1, cut_transforms=transforms, return_cuts=self.args.return_cuts, ) @@ -576,4 +589,4 @@ def test_cuts(self) -> CutSet: logging.info("CommonVoice: About to get test cuts") return load_manifest_lazy( self.manifest_dir / "cv22-en_cuts_test.jsonl.gz" - ) \ No newline at end of file + ) diff --git a/egs/librispeech/ASR/zapformer/attention_decoder.py b/egs/librispeech/ASR/zapformer/attention_decoder.py deleted file mode 120000 index 830180a0cd..0000000000 --- a/egs/librispeech/ASR/zapformer/attention_decoder.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 42471da4b2..e3ed9ecc04 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -351,11 +351,7 @@ def adam_step(group, state, grad): class BatchedRubik(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - + Implements a batched version of the Rubik optimizer. Args: params: The parameters or param_groups to optimize (like other Optimizer subclasses) @@ -564,8 +560,6 @@ def step(self, closure=None): def _test_batched_rubik(hidden_dim: int): import timeit - from scaling import OrthogonalLinear - E = 100 B = 4 T = 2 @@ -675,7 +669,6 @@ def _test_muon(hidden_dim: int): import timeit from muon import Muon - from scaling import OrthogonalLinear E = 100 B = 4 @@ -698,8 +691,6 @@ def _test_muon(hidden_dim: int): m = torch.nn.Sequential( Linear(E, hidden_dim), - OrthogonalLinear(hidden_dim, hidden_dim, bias=True, - in_groups=2, group_size=hidden_dim//4), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), torch.nn.PReLU(), diff --git a/egs/librispeech/ASR/zapformer/beam_search.py b/egs/librispeech/ASR/zapformer/beam_search.py deleted file mode 120000 index 8554e44ccf..0000000000 --- a/egs/librispeech/ASR/zapformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/beam_search.py b/egs/librispeech/ASR/zapformer/beam_search.py new file mode 100644 index 0000000000..66c84b2a94 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/beam_search.py @@ -0,0 +1,3183 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +from torch import nn + +from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost +from icefall.decode import Nbest, one_best_decoding +from icefall.lm_wrapper import LmScorer +from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM +from icefall.utils import ( + DecodingResults, + KeywordResult, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) + + +def fast_beam_search_one_best( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + ilme_scale: float = 0.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ilme_scale=ilme_scale, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ) + + best_path = one_best_decoding(lattice) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_LG( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + blank_penalty: float = 0.0, + ilme_scale: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ilme_scale=ilme_scale, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, + log_semiring=True, # Note: we always use True + ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + blank_penalty=blank_penalty, + temperature=temperature, + allow_partial=allow_partial, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_oracle( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + ref_texts: List[List[int]], + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + we select `num_paths` linear paths from the lattice. The path + that has the minimum edit distance with the given reference transcript + is used as the output. + + This is the best result we can achieve for any nbest based rescoring + methods. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + ref_texts: + A list-of-list of integers containing the reference transcripts. + If the decoding_graph is a trivial_graph, the integer ID is the + BPE token ID. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + hyps = nbest.build_levenshtein_graphs() + refs = k2.levenshtein_graph(ref_texts, device=hyps.device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, + ) + + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, + allow_partial: bool = False, + blank_penalty: float = 0.0, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + log_probs = (logits / temperature).log_softmax(dim=-1) + + if ilme_scale != 0: + ilme_logits = model.joiner( + torch.zeros_like( + current_encoder_out, device=current_encoder_out.device + ).unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + ilme_logits = ilme_logits.squeeze(1).squeeze(1) + if blank_penalty != 0: + ilme_logits[:, 0] -= blank_penalty + ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) + log_probs -= ilme_scale * ilme_log_probs + + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output( + encoder_out_lens.tolist(), allow_partial=allow_partial + ) + + return lattice + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + unk_id = getattr(model, "unk_id", blank_id) + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + + # Maximum symbols per utterance. + max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits is (1, 1, 1, vocab_size) + + if blank_penalty != 0: + logits[:, :, :, 0] -= blank_penalty + + y = logits.argmax().item() + if y not in (blank_id, unk_id): + hyp.append(y) + timestamp.append(t) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + if not return_timestamps: + return hyp + else: + return DecodingResults( + hyps=[hyp], + timestamps=[timestamp], + ) + + +def greedy_search_batch( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + blank_penalty: float = 0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = next(model.parameters()).device + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] + + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + # scores[n][i] is the logits on which hyp[n][i] is decoded + scores = [[] for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out: (N, 1, decoder_out_dim) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v not in (blank_id, unk_id): + hyps[i].append(v) + timestamps[i].append(t) + scores[i].append(logits[i, v].item()) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + ans_timestamps = [] + ans_scores = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) + ans_scores.append(scores[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + scores=ans_scores, + ) + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + ac_probs: Optional[List[float]] = None + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + # the lm score for next token given the current ys + lm_score: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # N-gram LM state + state_cost: Optional[NgramLmStateCost] = None + + # Context graph state + context_state: Optional[ContextState] = None + + num_tailing_blanks: int = 0 + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + """ + hyps = list(self._data.items()) + + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def keywords_search( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + keywords_graph: ContextGraph, + beam: int = 4, + num_tailing_blanks: int = 0, + blank_penalty: float = 0, +) -> List[List[KeywordResult]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + keywords_graph: + A instance of ContextGraph containing keywords and their configurations. + beam: + Number of active paths during the beam search. + num_tailing_blanks: + The number of tailing blanks a keyword should be followed, this is for the + scenario that a keyword will be the prefix of another. In most cases, you + can just set it to 0. + blank_penalty: + The score used to penalize blank probability. + Returns: + Return a list of list of KeywordResult. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert keywords_graph is not None + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=keywords_graph.root, + timestamp=[], + ac_probs=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + sorted_ans = [[] for _ in range(N)] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + probs = logits.softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs = probs.log() + + probs = probs.reshape(-1) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + hyp_probs = ragged_probs[i].tolist() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + new_ac_probs = hyp.ac_probs[:] + context_score = 0 + new_context_state = hyp.context_state + new_num_tailing_blanks = hyp.num_tailing_blanks + 1 + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + new_ac_probs.append(hyp_probs[topk_indexes[k]]) + ( + context_score, + new_context_state, + _, + ) = keywords_graph.forward_one_step(hyp.context_state, new_token) + new_num_tailing_blanks = 0 + if new_context_state.token == -1: # root + new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id] + + new_log_prob = topk_log_probs[k] + context_score + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ac_probs=new_ac_probs, + context_state=new_context_state, + num_tailing_blanks=new_num_tailing_blanks, + ) + B[i].add(new_hyp) + + top_hyp = B[i].get_most_probable(length_norm=True) + matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) + if matched: + ac_prob = ( + sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level + ) + if ( + matched + and top_hyp.num_tailing_blanks > num_tailing_blanks + and ac_prob >= matched_state.ac_threshold + ): + keyword = KeywordResult( + hyps=top_hyp.ys[-matched_state.level :], + timestamps=top_hyp.timestamp[-matched_state.level :], + phrase=matched_state.phrase, + ) + sorted_ans[i].append(keyword) + B[i] = HypothesisList() + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=keywords_graph.root, + timestamp=[], + ac_probs=[], + ) + ) + + B = B + finalized_B + + for i, hyps in enumerate(B): + top_hyp = hyps.get_most_probable(length_norm=True) + matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) + if matched: + ac_prob = ( + sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level + ) + if matched and ac_prob >= matched_state.ac_threshold: + keyword = KeywordResult( + hyps=top_hyp.ys[-matched_state.level :], + timestamps=top_hyp.timestamp[-matched_state.level :], + phrase=matched_state.phrase, + ) + sorted_ans[i].append(keyword) + + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + return ans + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_graph: Optional[ContextGraph] = None, + beam: int = 4, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=None if context_graph is None else context_graph.root, + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + + new_log_prob = topk_log_probs[k] + context_score + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + context_state=new_context_state, + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + ) + + +def modified_beam_search_lm_rescore( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + # get the best hyp with different lm_scale + for lm_scale in lm_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}" + tot_scores = am_scores.values + lm_scores * lm_scale + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def modified_beam_search_lm_rescore_LODR( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + LODR_lm: NgramLm, + sp: spm.SentencePieceProcessor, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + # now LODR scores + import math + + LODR_scores = [] + for seq in candidate_seqs: + tokens = " ".join(sp.id_to_piece(seq)) + LODR_scores.append(LODR_lm.score(tokens)) + LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( + 10 + ) # arpa scores are 10-based + assert lm_scores.shape == LODR_scores.shape + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + LODR_scale_list = [0.05 * i for i in range(1, 20)] + # get the best hyp with different lm_scale and lodr_scale + for lm_scale in lm_scale_list: + for lodr_scale in LODR_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" + tot_scores = ( + am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale + ) + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def _deprecated_modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] + new_token = topk_token_indexes[i] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ) + ) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False, + ) + + if blank_penalty != 0: + logits[:, :, :, 0] -= blank_penalty + + # TODO(fangjun): Scale the blank posterior + log_prob = (logits / temperature).log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i in (blank_id, unk_id): + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def fast_beam_search_with_nbest_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model. The shortest path within the + lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def fast_beam_search_with_nbest_rnn_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + rnn_lm_model: torch.nn.Module, + rnn_lm_scale_list: List[float], + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model and a rnn-lm. + The shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + rnn_lm_model: + A rnn-lm model used for LM rescoring + rnn_lm_scale_list: + A list of floats representing RNN score scales. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + # Now RNN-LM + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("sos_id") + eos_id = sp.piece_to_id("eos_id") + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_list) + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ans: Dict[str, List[List[int]]] = {} + for n_scale in ngram_lm_scale_list: + for rnn_scale in rnn_lm_scale_list: + key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" + tot_scores = ( + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def modified_beam_search_ngram_rescoring( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_LODR( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LODR_lm: NgramLm, + LODR_lm_scale: float, + LM: LmScorer, + beam: int = 4, + context_graph: Optional[ContextGraph] = None, +) -> List[List[int]]: + """This function implements LODR (https://arxiv.org/abs/2203.16776) with + `modified_beam_search`. It uses a bi-gram language model as the estimate + of the internal language model and subtracts its score during shallow fusion + with an external language model. This implementation uses a RNNLM as the + external language model. + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, # state of the NN LM + lm_score=init_score.reshape(-1), + state_cost=NgramLmStateCost( + LODR_lm + ), # state of the source domain ngram + context_state=None if context_graph is None else context_graph.root, + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + LM will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + # forward NN LM to get new states and scores + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + # current score of hyp + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state + if new_token not in (blank_id, unk_id): + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + + ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score + + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.state_cost.lm_score, + ) + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here + hyp_log_prob += ( + lm_score[new_token] * lm_scale + + LODR_lm_scale * current_ngram_score + + context_score + ) # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + else: + state_cost = hyp.state_cost + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + state_cost=state_cost, + context_state=new_context_state, + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_lm_shallow_fusion( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + beam: int = 4, + return_timestamps: bool = False, +) -> List[List[int]]: + """Modified_beam_search + NN LM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LM (LmScorer): + A neural net LM, e.g RNN or Transformer + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + lm_scores = torch.cat( + [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + `LM` will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] # a list of list + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + ys.append(new_token) + new_timestamp.append(t) + + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + timestamp=new_timestamp, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + ) diff --git a/egs/librispeech/ASR/zapformer/ctc_decode.py b/egs/librispeech/ASR/zapformer/ctc_decode.py index f3bce1b43d..cbbc7313d1 100755 --- a/egs/librispeech/ASR/zapformer/ctc_decode.py +++ b/egs/librispeech/ASR/zapformer/ctc_decode.py @@ -22,48 +22,48 @@ Usage: (1) ctc-greedy-search -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --max-duration 600 \ --decoding-method ctc-greedy-search (2) ctc-decoding -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --max-duration 600 \ --decoding-method ctc-decoding (3) 1best -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --max-duration 600 \ --hlg-scale 0.6 \ --decoding-method 1best (4) nbest -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --max-duration 600 \ --hlg-scale 0.6 \ --decoding-method nbest (5) nbest-rescoring -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --max-duration 600 \ --hlg-scale 0.6 \ @@ -72,10 +72,10 @@ --decoding-method nbest-rescoring (6) whole-lattice-rescoring -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --max-duration 600 \ --hlg-scale 0.6 \ @@ -84,20 +84,20 @@ --decoding-method whole-lattice-rescoring (7) attention-decoder-rescoring-no-ngram -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --use-attention-decoder 1 \ --max-duration 100 \ --decoding-method attention-decoder-rescoring-no-ngram (8) attention-decoder-rescoring-with-ngram -./zipformer/ctc_decode.py \ +./zapformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --use-ctc 1 \ --use-attention-decoder 1 \ --max-duration 100 \ @@ -267,7 +267,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="zapformer/exp", help="The experiment dir", ) @@ -529,7 +529,7 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. + # this seems to cause insertions at the end of the utterance if used with zapformer. pad_len = 30 feature_lens += pad_len feature = torch.nn.functional.pad( diff --git a/egs/librispeech/ASR/zapformer/decode.py b/egs/librispeech/ASR/zapformer/decode.py index 90c7a4a309..d7cb11e752 100755 --- a/egs/librispeech/ASR/zapformer/decode.py +++ b/egs/librispeech/ASR/zapformer/decode.py @@ -19,36 +19,36 @@ """ Usage: (1) greedy search -./zipformer/decode.py \ +./zapformer/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --max-duration 600 \ --decoding-method greedy_search (2) beam search (not recommended) -./zipformer/decode.py \ +./zapformer/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./zipformer/decode.py \ +./zapformer/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search (one best) -./zipformer/decode.py \ +./zapformer/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ @@ -56,10 +56,10 @@ --max-states 64 (5) fast beam search (nbest) -./zipformer/decode.py \ +./zapformer/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -69,10 +69,10 @@ --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ +./zapformer/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -82,10 +82,10 @@ --nbest-scale 0.5 (7) fast beam search (with LG) -./zipformer/decode.py \ +./zapformer/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ @@ -265,7 +265,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="zapformer/exp", help="The experiment dir", ) @@ -532,7 +532,7 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. + # this seems to cause insertions at the end of the utterance if used with zapformer. pad_len = 30 feature_lens += pad_len feature = torch.nn.functional.pad( diff --git a/egs/librispeech/ASR/zapformer/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer/decode_gigaspeech.py deleted file mode 120000 index 63b0ef617b..0000000000 --- a/egs/librispeech/ASR/zapformer/decode_gigaspeech.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decode_stream.py b/egs/librispeech/ASR/zapformer/decode_stream.py deleted file mode 120000 index 4e59d04a12..0000000000 --- a/egs/librispeech/ASR/zapformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decoder.py b/egs/librispeech/ASR/zapformer/decoder.py deleted file mode 120000 index cab465d2b9..0000000000 --- a/egs/librispeech/ASR/zapformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/encoder_interface.py b/egs/librispeech/ASR/zapformer/encoder_interface.py deleted file mode 120000 index aa5d0217a8..0000000000 --- a/egs/librispeech/ASR/zapformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/encoder_interface.py b/egs/librispeech/ASR/zapformer/encoder_interface.py new file mode 100644 index 0000000000..257facce4f --- /dev/null +++ b/egs/librispeech/ASR/zapformer/encoder_interface.py @@ -0,0 +1,43 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn as nn + + +class EncoderInterface(nn.Module): + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (batch_size, input_seq_len, num_features) + containing the input features. + x_lens: + A tensor of shape (batch_size,) containing the number of frames + in `x` before padding. + Returns: + Return a tuple containing two tensors: + - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) + containing unnormalized probabilities, i.e., the output of a + linear layer. + - encoder_out_lens, a tensor of shape (batch_size,) containing + the number of frames in `encoder_out` before padding. + """ + raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/librispeech/ASR/zapformer/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer/export-onnx-ctc.py deleted file mode 120000 index dc14e93e75..0000000000 --- a/egs/librispeech/ASR/zapformer/export-onnx-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py deleted file mode 120000 index 3baa2b673c..0000000000 --- a/egs/librispeech/ASR/zapformer/export-onnx-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer/export-onnx-streaming.py deleted file mode 120000 index d18cb9a9a1..0000000000 --- a/egs/librispeech/ASR/zapformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export-onnx.py b/egs/librispeech/ASR/zapformer/export-onnx.py deleted file mode 120000 index f343cf7027..0000000000 --- a/egs/librispeech/ASR/zapformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/export.py b/egs/librispeech/ASR/zapformer/export.py deleted file mode 120000 index 1a126ab695..0000000000 --- a/egs/librispeech/ASR/zapformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/finetune.py b/egs/librispeech/ASR/zapformer/finetune.py deleted file mode 120000 index 0e9e7989b9..0000000000 --- a/egs/librispeech/ASR/zapformer/finetune.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/generate_averaged_model.py b/egs/librispeech/ASR/zapformer/generate_averaged_model.py deleted file mode 120000 index b65513a058..0000000000 --- a/egs/librispeech/ASR/zapformer/generate_averaged_model.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained.py b/egs/librispeech/ASR/zapformer/jit_pretrained.py deleted file mode 120000 index 5d45825206..0000000000 --- a/egs/librispeech/ASR/zapformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py deleted file mode 120000 index 43aeb684bf..0000000000 --- a/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py new file mode 100755 index 0000000000..1430b97109 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) whole-lattice-rescoring +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from ctc_decode import get_decoding_params +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a token table, + i.e., lang_dir/token.txt, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + nbest n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + whole-lattice n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(features, feature_lengths) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + batch_size = ctc_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i].item() // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + max_token_id = params.vocab_size - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = [[token_table[i] for i in ids] for ids in token_ids] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py deleted file mode 120000 index 8e5e6f9812..0000000000 --- a/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py new file mode 100755 index 0000000000..9d85756b1e --- /dev/null +++ b/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zapformer/jit_pretrained_streaming.py \ + --nn-model-filename ./zapformer/exp-causal/jit_script_chunk_16_left_128.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + /path/to/foo.wav \ +""" + +import argparse +import logging +from typing import List, Optional + +import k2 +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model jit_script.pt", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, + device: torch.device = torch.device("cpu"), +): + assert encoder_out.ndim == 2 + context_size = decoder.context_size + blank_id = decoder.blank_id + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor( + decoder_input, dtype=torch.int32, device=device + ).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + model.eval() + model.to(device) + + encoder = model.encoder + decoder = model.decoder + joiner = model.joiner + + token_table = k2.SymbolTable.from_file(args.tokens) + context_size = decoder.context_size + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + + chunk_length = encoder.chunk_size * 2 + T = chunk_length + encoder.pad_length + + logging.info(f"chunk_length: {chunk_length}") + logging.info(f"T: {T}") + + states = encoder.get_init_states(device=device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).to(device).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32, device=device) + encoder_out, out_lens, states = encoder( + features=frames, + feature_lengths=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device + ) + + text = "" + for i in hyp[context_size:]: + text += token_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/joiner.py b/egs/librispeech/ASR/zapformer/joiner.py deleted file mode 120000 index 444cb5f150..0000000000 --- a/egs/librispeech/ASR/zapformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/joiner.py b/egs/librispeech/ASR/zapformer/joiner.py new file mode 100644 index 0000000000..5cf7b42bd2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/joiner.py @@ -0,0 +1,69 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from zapformer_modules import ScaledLinear + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + super().__init__() + + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim, ( + encoder_out.shape, + decoder_out.shape, + ) + + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + else: + logit = encoder_out + decoder_out + + # the scale of 2.0 is arbitrary, it is intended to modulate the speed at which joiner.output_linear trains, + # make it train faster by reducing its scale. + logit = 2.0 * self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/librispeech/ASR/zapformer/label_smoothing.py b/egs/librispeech/ASR/zapformer/label_smoothing.py deleted file mode 120000 index 3690afff9d..0000000000 --- a/egs/librispeech/ASR/zapformer/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/label_smoothing.py b/egs/librispeech/ASR/zapformer/label_smoothing.py new file mode 100644 index 0000000000..52d2eda3bb --- /dev/null +++ b/egs/librispeech/ASR/zapformer/label_smoothing.py @@ -0,0 +1,109 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. + """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) + + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + ) + + # Set the value of ignored indexes to 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index e04a10aa7d..61ac25e3d3 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -24,7 +24,7 @@ from torch import Tensor from encoder_interface import EncoderInterface -from scaling import ScaledLinear, convert_num_channels +from zapformer_modules import ScaledLinear from icefall.utils import add_sos, make_pad_mask diff --git a/egs/librispeech/ASR/zapformer/multicopy_dataset.py b/egs/librispeech/ASR/zapformer/multicopy_dataset.py index f445adbe1f..2e6f145690 100755 --- a/egs/librispeech/ASR/zapformer/multicopy_dataset.py +++ b/egs/librispeech/ASR/zapformer/multicopy_dataset.py @@ -13,22 +13,11 @@ class MulticopyDataset(torch.utils.data.Dataset): """ This is slightly modified from lhotse's K2SpeechRecognitionDataset, but - modified as suggested by Piotr in this github thread: + to support multiple parallel copies of the data, with augmentation applied + differently. + It uses ideas from Piotr in this thread: https://github.com/k2-fsa/icefall/pull/1975 - If cut_transforms is specified, which will normally be the case for training - data, where you might specify Musan augmentation, it returns two copies of - the data that differ only in the augmentations, followed by a third unmodified - copy. The structure of the data would be [ a b c d a b c d a b c d ], i.e. - the order is: first copy of all buts, second copy of all cuts, unmodified - copy of all cuts. - If cut_transforms is not specified, this dataset behaves like lhotse's regular - K2SpeechRecognitionDataset. - The yielded dict will have an extra key called "num_copies", set to 3 if - we did the 2 augmentation copies plus one original copy, or 1 if there - were no augmentations. - - This dataset expects to be queried with lists of cut IDs, for which it loads features and automatically collates/batches them. @@ -75,6 +64,7 @@ class MulticopyDataset(torch.utils.data.Dataset): def __init__( self, return_cuts: bool = False, + num_copies: int = 1, cut_transforms: List[Callable[[CutSet], CutSet]] = None, input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, input_strategy: BatchIO = PrecomputedFeatures(), @@ -99,6 +89,7 @@ def __init__( self.cut_transforms = ifnone(cut_transforms, []) self.input_transforms = ifnone(input_transforms, []) self.input_strategy = input_strategy + self.num_copies = num_copies # This attribute is a workaround to constantly growing HDF5 memory # throughout the epoch. It regularly closes open file handles to @@ -117,19 +108,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]] # Sort the cuts by duration so that the first one determines the batch time dimensions. cuts = cuts.sort_by_duration(ascending=False) - if self.cut_transforms: - orig_cuts = cuts - - cuts = cuts.repeat(times=4) - - for tnfm in self.cut_transforms: - cuts = tnfm(cuts) - - #cuts = orig_cuts + cuts - num_copies = 4 - else: - num_copies = 1 - + cuts = cuts.repeat(times=self.num_copies) # Get a tensor with batched feature matrices, shape (B, T, F) # Collation performs auto-padding, if necessary. @@ -155,7 +134,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]] batch = { "inputs": inputs, - "num_copies": num_copies, + "num_copies": self.num_copies, "supervisions": default_collate( [ { diff --git a/egs/librispeech/ASR/zapformer/muon.py b/egs/librispeech/ASR/zapformer/muon.py deleted file mode 120000 index 847edc7f4c..0000000000 --- a/egs/librispeech/ASR/zapformer/muon.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/muon.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/muon.py b/egs/librispeech/ASR/zapformer/muon.py new file mode 100644 index 0000000000..df69d1c166 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/muon.py @@ -0,0 +1,284 @@ +# Copyright 2025 Moonshot AI and the LlamaFactory team. +# +# This code is based on the MoonshotAI's Moonlight library. +# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py +# and the Keller Jordan's Muon library. +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2025 Moonshot AI +# Copyright (c) 2024 Keller Jordan +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math +import torch +import logging +import random + + + + +def norm4(X): + XX = X @ X.T + if random.random() < 0.0001: + norm2 = X.norm() + norm4 = XX.norm().sqrt() + logging.info(f"shape={X.shape}, norm2={norm2} vs norm4={norm4}") + return XX.norm().sqrt() + +def get_muon_shape(shape): + shape = list(shape) + def prod(l): + ans = l[0] + for n in l[1:]: + ans = ans * n + return ans + n = len(shape) + diffs = [ ] + for i in range(1, n): + prod1 = prod(shape[:i]) + prod2 = prod(shape[i:]) + diff = abs(prod1 - prod2) + diffs.append(diff) + min_diff = min(diffs) + for i in range(1, n): + if diffs[i-1] == min_diff: + return prod(shape[:i]), prod(shape[i:]) + +def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int, state: dict) -> "torch.Tensor": + """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. + + We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. + For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing + the slope at zero even beyond the point where the iteration no longer converges all the way to + one everywhere on the interval. This iteration therefore does not produce UV^T but rather something + like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + orig_shape = G.shape + G = G.reshape(get_muon_shape(orig_shape)) + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + + if "delta2_buffer0" not in state: + state["delta2_buffer0"] = torch.ones(X.shape[0], device=X.device, dtype=X.dtype) + state["delta2_buffer1"] = torch.ones(X.shape[1], device=X.device, dtype=X.dtype) + delta2_buffer0 = state["delta2_buffer0"] + delta2_buffer1 = state["delta2_buffer1"] + + + eps = 1e-7 + + # we'll scale both before and after the newton-schulz + row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt().unsqueeze(-1) * (delta2_buffer1 + eps).sqrt()) + X = X * row_col_scale + + # Ensure spectral 4-norm is at most 1 + X = X / (norm4(X) + eps) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + # the following scales so if the newton-schulz was exact, the elements of X would have unit RMS. + X = X * (max(X.shape[0], X.shape[1]) ** 0.5) + X2 = X ** 2 + beta = 0.98 + delta2_buffer0.mul_(beta).add_(X2.mean(dim=1), alpha=(1 - beta)) + delta2_buffer1.mul_(beta).add_(X2.mean(dim=0), alpha=(1 - beta)) + + X = X * row_col_scale + + if G.size(0) > G.size(1): + X = X.T + + return X.reshape(orig_shape) + + +class Muon(torch.optim.Optimizer): + """Muon - MomentUm Orthogonalized by Newton-schulz. + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + wd: weight decay for muon and adamw, this is a squared type of weight decay, requires a large value + which dimensionally is like an inverse of a parameter rms + """ + def __init__( + self, + params, + lr=1e-3, + wd=10.0, # weight decay is a squared type, needs larger wd value, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + scale_limits=(0.5, 4.0), + ): + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + scale_limits=scale_limits, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + # Muon loop + params = [p for p in group["params"] if p.numel() != max(p.shape, default=1)] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + min_scale, max_scale = group["scale_limits"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + state["scale"] = torch.tensor(1.0, device=g.device) # scalar + state["scale_grad_buffer"] = torch.tensor(0.0, device=g.device) # scalar + buf = state["momentum_buffer"] + scale = state["scale"] + scale_grad_buf = state["scale_grad_buffer"] + buf.mul_(momentum).add_(g) + + scale_grad = (g * p.detach()).sum() + scale_grad_buf.mul_(momentum).add_(scale_grad) + + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + eps = 1.0e-08 + + + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], state=state) + + # multiplying by 0.2 is what's left of adjust_lr_for_muon(), + # we used the factor of (max(p.shape[0], p.shape[1]) ** 0.5) inside + # zeropower_via_newtonschulz5. + adjusted_lr = 0.2 * lr + + old_scale = scale.clone() + + scale.add_(scale_grad_buf.sign(), alpha=-lr) + scale.clamp_(min=min_scale, max=max_scale) + + scale_ratio = scale / old_scale + + # apply changes in scale, together with conventional decay. + p.data.mul_(scale_ratio * (1 - (lr * wd) ** 2)) + + # apply update + p.data.add_(u * scale, alpha=-adjusted_lr) + + # Adam backup + params = [p for p in group["params"] if p.numel() == max(p.shape, default=1)] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - (lr * weight_decay) ** 2) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/egs/librispeech/ASR/zapformer/my_profile.py b/egs/librispeech/ASR/zapformer/my_profile.py deleted file mode 120000 index 76e48b756b..0000000000 --- a/egs/librispeech/ASR/zapformer/my_profile.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/my_profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/my_profile.py b/egs/librispeech/ASR/zapformer/my_profile.py new file mode 100755 index 0000000000..333b139689 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/my_profile.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: ./zapformer/my_profile.py +""" + +import argparse +import logging +from typing import Tuple + +import sentencepiece as spm +import torch +from torch import Tensor, nn +from train import ( + add_model_arguments, + get_encoder_embed, + get_encoder_model, + get_joiner_model, + get_params, +) +from zapformer import BypassModule + +from icefall.profiler import get_model_profile +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + add_model_arguments(parser) + + return parser + +class Model(nn.Module): + """A Wrapper for encoder, encoder_embed, and encoder_proj""" + + def __init__( + self, + encoder: nn.Module, + encoder_embed: nn.Module, + encoder_proj: nn.Module, + ) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: + x, x_lens = self.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + logits = self.encoder_proj(encoder_out) + + return logits, encoder_out_lens + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + + # We only profile the encoder part + model = Model( + encoder=get_encoder_model(params), + encoder_embed=get_encoder_embed(params), + encoder_proj=get_joiner_model(params).encoder_proj, + ) + model.eval() + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # for 30-second input + B, T, D = 1, 3000, 80 + feature = torch.ones(B, T, D, dtype=torch.float32).to(device) + feature_lens = torch.full((B,), T, dtype=torch.int64).to(device) + + flops, params = get_model_profile( + model=model, + args=(feature, feature_lens), + #module_hoop_mapping=MODULE_HOOK_MAPPING, + ) + logging.info(f"For the encoder part, params: {params}, flops: {flops}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_check.py b/egs/librispeech/ASR/zapformer/onnx_check.py deleted file mode 120000 index 7293c70d46..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_check.py b/egs/librispeech/ASR/zapformer/onnx_check.py new file mode 100755 index 0000000000..c248ea6487 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_check.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./zapformer/export.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - jit_script.pt + +3. Export the model to ONNX + +./zapformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. Run this file + +./zapformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + +import torch +from onnx_pretrained import OnnxModel + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + return parser + + +def test_encoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") + + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T + + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) + + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) + + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + test_encoder(torch_model, onnx_model) + + logging.info("Test decoder") + test_decoder(torch_model, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_model, onnx_model) + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_decode.py b/egs/librispeech/ASR/zapformer/onnx_decode.py deleted file mode 120000 index 9e3faa5e01..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_decode.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_decode.py b/egs/librispeech/ASR/zapformer/onnx_decode.py new file mode 100755 index 0000000000..075474a6bf --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_decode.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zapformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +2. Run this file + +./zapformer/onnx_decode.py \ + --exp-dir $repo/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + The token table. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + hyps = [token_ids_to_words(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + The token table. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = SymbolTable.from_file(args.tokens) + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py deleted file mode 120000 index f8abb9daa5..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py new file mode 100755 index 0000000000..8d2cebb54c --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming-ctc.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zapformer-ctc-streaming-2023-11-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zapformer-ctc-streaming-2023-11-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zapformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 1 + +It will generate the following 2 files inside $repo/exp: + + - ctc-epoch-99-avg-1-chunk-16-left-128.int8.onnx + - ctc-epoch-99-avg-1-chunk-16-left-128.onnx + +You can use either the ``int8.onnx`` model or just the ``.onnx`` model. + +3. Run this file with the exported ONNX models + +./zapformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-99-avg-1-chunk-16-left-128.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Tuple + +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(model_filename) + + def init_model(self, model_filename: str): + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + self.init_states() + + def init_states(self, batch_size: int = 1): + meta = self.model.get_modelmeta().custom_metadata_map + logging.info(f"meta={meta}") + + model_type = meta["model_type"] + assert model_type == "zapformer2", model_type + + decode_chunk_len = int(meta["decode_chunk_len"]) + T = int(meta["T"]) + + num_encoder_layers = meta["num_encoder_layers"] + encoder_dims = meta["encoder_dims"] + cnn_module_kernels = meta["cnn_module_kernels"] + left_context_len = meta["left_context_len"] + query_head_dims = meta["query_head_dims"] + value_head_dims = meta["value_head_dims"] + num_heads = meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def _build_model_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + model_input = {"x": x.numpy()} + model_output = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + model_input[name] = tensors[0] + model_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + model_input[name] = tensors[1] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + model_input[name] = tensors[2] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + model_input[name] = tensors[3] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + model_input[name] = tensors[4] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + model_input[name] = tensors[5] + model_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + model_input[name] = embed_states + model_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + model_input[name] = processed_lens + model_output.append(f"new_{name}") + + return model_input, model_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor containing log_probs. Its shape is (N, T, vocab_size) + where T' is usually equal to ((T-7)//2 - 3)//2 + """ + model_input, model_output_names = self._build_model_input_output(x) + + out = self.model.run(model_output_names, model_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + return OnlineFbank(opts) + + +def greedy_search( + log_probs: torch.Tensor, +) -> List[int]: + """Greedy search for a single utterance. + Args: + log_probs: + A 3-D tensor of shape (1, T, vocab_size) + Returns: + Return the decoded result. + """ + assert log_probs.ndim == 3, log_probs.shape + assert log_probs.shape[0] == 1, log_probs.shape + + max_indexes = log_probs[0].argmax(dim=1) + unique_indexes = torch.unique_consecutive(max_indexes) + + blank_id = 0 + unique_indexes = unique_indexes[unique_indexes != blank_id] + return unique_indexes.tolist() + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel(model_filename=args.model_filename) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + hyp = [] + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + log_probs = model(frames) + + hyp += greedy_search(log_probs) + + # To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes + id2token = {} + with open(args.tokens, encoding="utf-8") as f: + for line in f: + token, idx = line.split() + if token[:3] == "<0x" and token[-1] == ">": + token = int(token[1:-1], base=16) + assert 0 <= token < 256, token + token = token.to_bytes(1, byteorder="little") + else: + token = token.encode(encoding="utf-8") + + id2token[int(idx)] = token + + text = b"" + for i in hyp: + text += id2token[i] + text = text.decode(encoding="utf-8") + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py deleted file mode 120000 index 11b846322e..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py new file mode 100755 index 0000000000..2d25805842 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zapformer-2023-05-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zapformer-2023-05-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zapformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./zapformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + logging.info(f"encoder_meta={encoder_meta}") + + model_type = encoder_meta["model_type"] + assert model_type == "zapformer2", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = encoder_meta["num_encoder_layers"] + encoder_dims = encoder_meta["encoder_dims"] + cnn_module_kernels = encoder_meta["cnn_module_kernels"] + left_context_len = encoder_meta["left_context_len"] + query_head_dims = encoder_meta["query_head_dims"] + value_head_dims = encoder_meta["value_head_dims"] + num_heads = encoder_meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + encoder_input[name] = tensors[0] + encoder_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + encoder_input[name] = tensors[1] + encoder_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + encoder_input[name] = tensors[2] + encoder_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + encoder_input[name] = tensors[3] + encoder_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + encoder_input[name] = tensors[4] + encoder_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + encoder_input[name] = tensors[5] + encoder_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + encoder_input[name] = embed_states + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + encoder_input[name] = processed_lens + encoder_output.append(f"new_{name}") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2-3)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + token_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += token_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained.py b/egs/librispeech/ASR/zapformer/onnx_pretrained.py deleted file mode 120000 index a085def837..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained.py b/egs/librispeech/ASR/zapformer/onnx_pretrained.py new file mode 100755 index 0000000000..cbbaa27c09 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zapformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file + +./zapformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, joiner_dim) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.run_decoder(decoder_input) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) + + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = model.run_decoder(decoder_input) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py deleted file mode 120000 index 0c082a204f..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py new file mode 100755 index 0000000000..457e2370bc --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zapformer/onnx_pretrained_ctc.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + blank_id = 0 + s = "\n" + for i in range(log_probs.size(0)): + # greedy search + indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1) + token_ids = torch.unique_consecutive(indexes) + + token_ids = token_ids[token_ids != blank_id] + words = token_ids_to_words(token_ids.tolist()) + s += f"{args.sound_files[i]}:\n{words}\n\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py deleted file mode 120000 index 68102c7374..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py new file mode 100755 index 0000000000..7472c61c5e --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zapformer/onnx_pretrained_ctc_H.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + --H /path/to/H.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zapformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import Dict, List, Tuple + +import k2 +import kaldifeat +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--H", + type=str, + help="""Path to H.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2word: + A map mapping token ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # tokens are incremented during graph construction + # are shifted by 1 during graph construction + hyps = [id2token[i - 1] for i in osymbols_out if i != 1] + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + H=H, + id2token=token_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py deleted file mode 120000 index 8314b4efdf..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py new file mode 100755 index 0000000000..9e11535b2b --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zapformer/onnx_pretrained_ctc_HL.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HL /path/to/HL.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zapformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import Dict, List, Tuple + +import k2 +import kaldifeat +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HL", + type=str, + help="""Path to HL.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HL=HL, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py deleted file mode 120000 index 7a637a1c01..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py new file mode 100755 index 0000000000..3d757386cb --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zapformer/onnx_pretrained_ctc_HLG.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HLG /path/to/HLG.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zapformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import Dict, List, Tuple + +import k2 +import kaldifeat +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HLG=HLG, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py deleted file mode 120000 index a5b04b3f8b..0000000000 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py new file mode 100755 index 0000000000..e823c8d5a2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming-ctc.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zapformer-small-2024-03-18 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zapformer-small-2024-03-18 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp-ctc-rnnt-small/*.pt" +git lfs pull --include "data/lang_bpe_500/words.txt" +git lfs pull --include "data/lang_bpe_500/HLG.fst" +popd + +2. Export the model to ONNX + +./zapformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 3 \ + --exp-dir $repo/exp-ctc-rnnt-small \ + --causal 1 \ + --use-ctc 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 + +It will generate the following 2 files inside $repo/exp-ctc-rnnt-small: + + - ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx + - ctc-epoch-30-avg-3-chunk-16-left-128.onnx + +You can use either the ``int8.onnx`` model or just the ``.onnx`` model. + +3. Run this file with the exported ONNX models + +python3 ./zapformer/onnx_pretrained_ctc_HLG_streaming.py \ + --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + --words $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + $repo/test_wavs/0.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. + +Note: HLG.fst is generated directly from ../local/prepare_lang_fst.py +""" + +import argparse +import logging +from typing import Dict, List, Tuple + +import k2 +import kaldifst +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + required=True, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(model_filename) + + def init_model(self, model_filename: str): + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + self.init_states() + + def init_states(self, batch_size: int = 1): + meta = self.model.get_modelmeta().custom_metadata_map + logging.info(f"meta={meta}") + + model_type = meta["model_type"] + assert model_type == "zapformer2", model_type + + decode_chunk_len = int(meta["decode_chunk_len"]) + T = int(meta["T"]) + + num_encoder_layers = meta["num_encoder_layers"] + encoder_dims = meta["encoder_dims"] + cnn_module_kernels = meta["cnn_module_kernels"] + left_context_len = meta["left_context_len"] + query_head_dims = meta["query_head_dims"] + value_head_dims = meta["value_head_dims"] + num_heads = meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def _build_model_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + model_input = {"x": x.numpy()} + model_output = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + model_input[name] = tensors[0] + model_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + model_input[name] = tensors[1] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + model_input[name] = tensors[2] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + model_input[name] = tensors[3] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + model_input[name] = tensors[4] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + model_input[name] = tensors[5] + model_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + model_input[name] = embed_states + model_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + model_input[name] = processed_lens + model_output.append(f"new_{name}") + + return model_input, model_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor containing log_probs. Its shape is (N, T, vocab_size) + where T' is usually equal to ((T-7)//2 - 3)//2 + """ + model_input, model_output_names = self._build_model_input_output(x) + + out = self.model.run(model_output_names, model_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + logging.info(f"Resample {sample_rate} to {expected_sample_rate}") + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + word_table = k2.SymbolTable.from_file(args.words) + model = OnnxModel(model_filename=args.nn_model) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.init_decoding() + + chunk = int(1 * sample_rate) # 1 second + start = 0 + + n = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + + # simulate streaming + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + + log_probs = model(frames) + log_probs = log_probs.squeeze(0).cpu().numpy() + + decodable = DecodableCtc(log_probs, offset=n) + n += log_probs.shape[0] + + num_processed_frames += offset + decoder.advance_decoding(decodable) + + if not decoder.reached_final(): + logging.info(f"Failed to decode {args.sound_file}") + return + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + + if not ok: + logging.info(f"Failed to get linear symbol sequence for {args.sound_file}") + return + + hyps = " ".join([word_table[i] for i in osymbols_out]).lower() + logging.info(f"\n{args.sound_file}\n{hyps}") + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/optim.py b/egs/librispeech/ASR/zapformer/optim.py deleted file mode 120000 index 207eecfcda..0000000000 --- a/egs/librispeech/ASR/zapformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/pretrained.py b/egs/librispeech/ASR/zapformer/pretrained.py deleted file mode 120000 index 70ad71ffc6..0000000000 --- a/egs/librispeech/ASR/zapformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/pretrained.py b/egs/librispeech/ASR/zapformer/pretrained.py new file mode 100755 index 0000000000..3dc98085ec --- /dev/null +++ b/egs/librispeech/ASR/zapformer/pretrained.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zapformer/exp/epoch-xx.pt`. + +Note: ./zapformer/exp/pretrained.pt is generated by ./zapformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/pretrained_ctc.py b/egs/librispeech/ASR/zapformer/pretrained_ctc.py deleted file mode 120000 index fb9bdf1fa2..0000000000 --- a/egs/librispeech/ASR/zapformer/pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/pretrained_ctc.py b/egs/librispeech/ASR/zapformer/pretrained_ctc.py new file mode 100755 index 0000000000..2cbd4098a9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/pretrained_ctc.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +(1) ctc-decoding +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(5) attention-decoder-rescoring-no-ngram +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method attention-decoder-rescoring-no-ngram \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from ctc_decode import get_decoding_params +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a token table, + i.e., lang_dir/tokens.txt, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + nbest n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + whole-lattice n-gram LM rescoring. + (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + params.vocab_size = num_tokens(token_table) + 1 # +1 for blank + params.blank_id = token_table[""] + params.sos_id = params.eos_id = token_table[""] + assert params.blank_id == 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + batch_size = ctc_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i].item() // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: + max_token_id = params.vocab_size - 1 + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + else: + logging.info("Use attention decoder rescoring without ngram") + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + + token_ids = get_texts(best_path) + hyps = [[token_table[i] for i in ids] for ids in token_ids] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/scaling.py b/egs/librispeech/ASR/zapformer/scaling.py deleted file mode 120000 index 58e4b0a0fe..0000000000 --- a/egs/librispeech/ASR/zapformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/scaling.py b/egs/librispeech/ASR/zapformer/scaling.py new file mode 100644 index 0000000000..06cb538627 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/scaling.py @@ -0,0 +1,1295 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import math +import copy +import random +from typing import Optional, Tuple, Union, Any + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + + + +class FloatLike: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class ScheduledFloat: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class SimpleOrthogonalLinear: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class PiecewiseLinear: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class CosineSimilarityLoss: # TODO: remove. this is to solve problems with multiple jobs running. + pass +class PredictLoss: # TODO: remove. this is to solve problems with multiple jobs running. + pass +get_max_similarity = None # TODO: remove. this is to solve problems with multiple jobs running. + + + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.get_autocast_gpu_dtype()) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.amp.autocast('cuda', enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim=dim) + + return SoftmaxFunction.apply(x, dim) + + + +# all arg tensors are scalars. +def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): + stats = (x ** 2).mean(dim=2, keepdim=True) + T = x.shape[0] # time + if mask is None: + stats = stats.sum(dim=0) + lengths = T + else: + mask = (~mask).to(torch.float).t().unsqueeze(-1) + stats = stats * mask + stats = stats.sum(dim=0) + lengths = mask.sum(dim=0) + + scales = (lengths / stats).sqrt() + assert scales.shape == (x.shape[1], 1) + return x * ((scale * scales) + offset) + +# all arg tensors are scalars. +# mask only used in non-causal mode; ballast_rms and ballast_frames only used in causal mode. +def _causal_sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor): + stats = (x ** 2).mean(dim=2, keepdim=True) + + # no need for mask in causal mode. + # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so + # make absolutely sure using abs(). + ballast_frames = 100.0 * ballast_frames.abs() + ballast = ballast_frames * (ballast_rms ** 2) + T = x.shape[0] # time + + stats = stats.cumsum(dim=0) + ballast + lengths = ballast_frames + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] + + scales = (lengths / stats).sqrt() + assert scales.shape == (T, x.shape[1], 1) + return x * ((scale * scales) + offset) + + +# all arg tensors are scalars +def _causal_sequence_norm_streaming( + x: Tensor, + offset: Tensor, + scale: Tensor, + cached_stats_sum: Tensor, + cached_len: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """Streaming inference forward for _sequence_norm. We assume that ballast_frames and ballast_rms + are already included in cached_stats_sum and cached_len. + + Args: + x: (seq_len, batch_size, channels) + offset: scalar + scale: scalar + cached_stats_sum: (batch_size,) + cached_len: (batch_size,) + + Returns: + - normalized x, (seq_len, batch_size, channels) + - updated cached_stats_sum, (batch_size,) + - updated cached_len, (batch_size,) + """ + stats = (x ** 2).mean(dim=2, keepdim=True) # (seq_len, batch_size, 1) + + T = x.shape[0] # time + + stats = stats.cumsum(dim=0) + cached_stats_sum.unsqueeze(-1) + lengths = cached_len[:, None] + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] + + # update cached_stats_sum and cached_len for the next chunk + cached_stats_sum = stats[-1].squeeze(-1) # (batch_size,) + cached_len = cached_len + T + + scales = (lengths / stats).sqrt() # (T, batch_size, 1) + assert scales.shape == (T, x.shape[1], 1) + return x * ((scale * scales) + offset), cached_stats_sum, cached_len + + +class CausalSequenceNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + offset: Tensor, + scale: Tensor, + ballast_rms: Tensor, + ballast_frames: Tensor, + ) -> Tensor: + ctx.save_for_backward(x, offset, scale, ballast_rms, ballast_frames) + + return _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, offset, scale, ballast_rms, ballast_frames = ctx.saved_tensors + + + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float32).detach().requires_grad_() + offset = offset.to(torch.float32).detach().requires_grad_() + scale = scale.to(torch.float32).detach().requires_grad_() + ballast_rms = ballast_rms.to(torch.float32).detach().requires_grad_() + ballast_frames = ballast_frames.to(torch.float32).detach().requires_grad_() + + with torch.enable_grad(): + ans = _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode: scalars will tend to have larger grads than non-scalars, + # this code is to reduce the probabilities that any infinities could crash the + # training (it may still happen if the world-size is so large that these + # infinities get added together though). + return x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad) + +class SequenceNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + offset: Tensor, + scale: Tensor, + mask: Optional[Tensor], + ) -> Tensor: + ctx.save_for_backward(x, offset, scale) + ctx.mask = mask + + return _sequence_norm(x, offset, scale, mask) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, offset, scale = ctx.saved_tensors + + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float32).detach().requires_grad_() + offset = offset.to(torch.float32).detach().requires_grad_() + scale = scale.to(torch.float32).detach().requires_grad_() + + with torch.enable_grad(): + ans = _sequence_norm(x, offset, scale, ctx.mask) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode: scalars will tend to have larger grads than non-scalars, + # this code is to reduce the probabilities that any infinities could crash the + # training (it may still happen if the world-size is so large that these + # infinities get added together though). + return x if x is None else x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(offset.grad), c(scale.grad), None + + +class CausalSequenceNorm(torch.nn.Module): + """ + This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence + up to the current point as well as the channels, with some padding of the stats with "default values" + determined by ballast_frames, ballast_rms for robustness near the beginning of the sequence. + + There is also a learnable scalar scale, multiplicatively applied to the output, and a learnable + "offset" value that acts multiplicatively on the input without taking into account the rms values. + """ + def __init__( + self, + ) -> None: + super().__init__() + self.scale = nn.Parameter(torch.tensor(0.5)) + self.offset = nn.Parameter(torch.tensor(0.0001)) + + # ballast_mean: assumed rms value of ballast frames used to pad stats + self.ballast_rms = nn.Parameter(torch.tensor(0.1)) + # ballast_frames: number of ballast frames, in hundreds (will be multiplied by 100) + self.ballast_frames = nn.Parameter(torch.tensor(0.05)) # number of ballast frames, will be multiplied by 100 + self.name = None + + def forward(self, x: Tensor, _mask: Optional[Tensor] = None) -> Tensor: + # x: (seq, batch, channel) + # The mask is ignored, it is allowed only for consistency of interface with SequenceNorm. + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _causal_sequence_norm(x, self.offset, self.scale, self.ballast_rms, self.ballast_frames) + + scale = limit_param_value( + self.scale, min=0.05, max=2.0, training=self.training) + + offset = limit_param_value( + self.offset, min=0.0, max=10.0, training=self.training) + + ballast_rms = limit_param_value( + self.ballast_rms, min=0.0, max=10.0, training=self.training) + + ballast_frames = limit_param_value( + self.ballast_frames, min=0.0, max=5.0, training=self.training) # max of 5.0 would be 500 frames + + ans = CausalSequenceNormFunction.apply( + x, offset, scale, ballast_rms, ballast_frames, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}, ballast_rms={self.ballast_rms.item()}, ballast_frames*100={100*self.ballast_frames.item()}") + + return ans + + @torch.jit.export + def get_init_cache(self, batch_size: int): + """Get initial cache for streaming inference. We first include the ballast stats in the initial cache. + """ + # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so + # make absolutely sure using abs(). + ballast_frames = 100.0 * self.ballast_frames.abs() + ballast = ballast_frames * (self.ballast_rms ** 2) + + cached_stats_sum = ballast.unsqueeze(0).repeat(batch_size) # (batch_size,) + cached_len = ballast_frames.unsqueeze(0).repeat(batch_size) # (batch_size,) + + return cached_stats_sum, cached_len + + def streaming_forward( + self, + x: Tensor, + cached_stats_sum: Tensor, + cached_len: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + + x, cached_stats_sum, cached_len = _causal_sequence_norm_streaming( + x, self.offset, self.scale, cached_stats_sum, cached_len) + return x, cached_stats_sum, cached_len + + +class SequenceNorm(torch.nn.Module): + """ + This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence + as well as the channels; and a padding mask is used for irregular length sequences (actually, + the mask is applied multiplicatively as well.) + + There is also a learnable scalar scale and a learnable "offset" value. + """ + def __init__( + self, + ) -> None: + super().__init__() + self.scale = nn.Parameter(torch.tensor(0.5)) + self.offset = nn.Parameter(torch.tensor(0.0001)) + self.name = None + + def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: + # x: (seq, batch, channel) + # mask: bool, (batch_size, seq_len) + # Note: mask is ignored in causal mode. + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _sequence_norm(x, self.offset, self.scale, mask) + + scale = limit_param_value( + self.scale, min=0.05, max=2.0, training=self.training) + + offset = limit_param_value( + self.offset, min=0.0, max=10.0, training=self.training) + + ans = SequenceNormFunction.apply( + x, offset, scale, mask, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}") + + return ans + + + +# assume layout: (time, batch, channel) +def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) + scales = scale / x_sq.sqrt() + return x * scales + + + +class RmsNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + eps: Tensor, + scale: Tensor, + ) -> Tensor: + ctx.save_for_backward(x, eps, scale) + return _rms_norm(x, eps, scale) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, eps, scale = ctx.saved_tensors + + with torch.amp.autocast('cuda', enabled=False): + x, eps, scale = x.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, eps, scale = x.detach(), eps.detach(), scale.detach() + + x.requires_grad = True + eps.requires_grad = True + scale.requires_grad = True + + with torch.enable_grad(): + ans = _rms_norm(x, eps, scale) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode. + return x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(eps.grad), c(scale.grad) + + +class RmsNorm(torch.nn.Module): + """ + This is like RMSNorm with a trainable scale. + + """ + def __init__( + self, + ) -> None: + super(RmsNorm, self).__init__() + self.scale = nn.Parameter(torch.tensor(0.2)) # output scale + self.eps = nn.Parameter(torch.tensor(0.1)) + self.name = None + + + def forward(self, x: Tensor) -> Tensor: + # Assumes layout is (time, batch, channel) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _rms_norm(x, self.eps, self.scale) + + scale = limit_param_value( + self.scale, min=0.05, max=1.0, training=self.training) + + eps = limit_param_value( + self.eps, min=0.0, max=10.0, training=self.training) + + ans = RmsNormFunction.apply( + x, eps, scale, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, eps={eps.item()}, scale={scale.item()}") + + return ans + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.01 * initial_scale, 0.01 * initial_scale) + return ans + + +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class OrthogonalPenaltyFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, weight: Tensor, penalty_scale: float, name: str): + ctx.save_for_backward(weight) + ctx.name = name + ctx.penalty_scale = penalty_scale + return weight + + @staticmethod + @custom_bwd + def backward(ctx, weight_grad): + weight, = ctx.saved_tensors + + if weight.requires_grad and ctx.penalty_scale != 0.0: + penalty_scale = ctx.penalty_scale * weight_grad.abs().mean() + + with torch.enable_grad(): + weight = weight.detach() + weight.requires_grad = True + + # Compute symmetric matrix-product prod with the smallest + # dimension possible given the shape of w. This is not just for + # efficiency; if we computed it the wrong way round, the product + # would have deficient rank and could never be the identity. + if (weight.shape[0] > weight.shape[1]): + prod = torch.matmul(weight.t(), weight) + else: + prod = torch.matmul(weight, weight.t()) + + # we'll try to enforce that for any i, prod[i] is any constant times the identity. + + # in the loss-function: + # orthogonality_loss = ((prod - I) ** 2).sum(), + + # note, prod_diag shares memory with prod, this will matter later on. + (r, c) = prod.shape + (r_stride, c_stride) = prod.stride() + + def diag_inplace(z): + return torch.as_strided(z, size=(r,), stride=(r_stride+c_stride,)) + + diag_inplace(prod)[:] -= 1. + + # that loss that we want to backprop would be 0.5 * (prod ** + # 2).sum() * penalty_scale. we can backprop this without doing + # any reductions as follows: + prod.backward(gradient=prod * penalty_scale) + + + do_print = random.random() < 0.002 + if do_print: + # we print a normalized version of the loss, by dividing by the + # number of rows. + loss = (prod ** 2).mean() + logging.info(f"OrthogonalLinear: name={ctx.name}, loss={loss.detach().cpu()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") + + + # add the extra gradient term from the orthogonality loss. + weight_grad = weight_grad + weight.grad + return weight_grad, None, None + +class OrthogonalLinear(nn.Linear): + """ + Like nn.Linear but can enforce that the weight matrix is orthogonal; in the non-square + case this is interpreted as either M^T M == I or M M^T == I, whichever would give a smaller + dimension. + (If M is square, these definitions are equivalent and is equivalent to the normal + definition of orthogonal). + + Args: + in_channels: number of input channels + out_channels: number of output channels + lr_scale: we will scale the weight by this value before applying the orthogonal + constraint and using it; with most optimizers + this will have the effect of slowing down the learning by this factor because + the parameter value will be larger. + bias: if True, include a bias term. + penalty_scale: a scale on the penalty on non-orthogonality (this will + be multiplied by the average-absolute-value of the + backpropagated gradient). + """ + # if in_groups or out_groups are set to >1, the orthogonal constraint + # will be set per group. both of them cannot be >1. + def __init__(self, + in_channels: int, + out_channels: int, + lr_scale: float = 1.0, + bias: bool = True, + penalty_scale: float = 20.0, + ): + super().__init__(in_channels, out_channels, bias=bias) + self.name = None + self.penalty_scale = copy.deepcopy(penalty_scale) + self.lr_scale = lr_scale + + with torch.no_grad(): + self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * (1. / lr_scale) + if self.bias is not None: + torch.nn.init.uniform_(self.bias, -0.01, 0.01) + + + def forward(self, x: Tensor, transpose: bool = False): + # you can only use transpose=True if you used bias=False in initialization + weight = self.weight + lr_scale = self.lr_scale + if lr_scale != 1.0: + weight = weight * lr_scale + if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): + weight = OrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) + + if transpose: + weight = weight.t() + return torch.nn.functional.linear(x, weight, self.bias) + + +class ScaleLimiterFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, max_rms: float, aux_loss_scale: float, name: str): + ctx.save_for_backward(x) + ctx.max_rms = max_rms + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + x, = ctx.saved_tensors + with torch.enable_grad(): + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + x = x.detach() + x.requires_grad = True + rms = (x ** 2).mean(dim=-1).sqrt() + numel = rms.numel() + + excess = (rms / ctx.max_rms - 1.).relu().mean() + + if random.random() < 0.002: + logging.info( + f"ScaleLimiter: name={ctx.name}, max_rms={ctx.max_rms}, " + f"rms={rms.mean().item()}, excess={excess.item()}, " + f"loss_scale={ctx.aux_loss_scale}" + ) + excess.backward(gradient=torch.full_like(excess, ctx.aux_loss_scale * numel)) + return x_grad + x.grad, None, None, None + + +class ScaleLimiter(torch.nn.Module): + """ + Adds a penalty in backprop if the norm of any activation vector is less than min_rms + or more than max_rms. + + Assumes channel dim is -1 and the input shape has >1 dimension. + """ + def __init__(self, max_rms: float): + super().__init__() + self.name = None + self.max_rms = max_rms + + + def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return _no_op(x) + else: + return ScaleLimiterFunction.apply(x, float(self.max_rms), + aux_loss_scale, self.name) + + +class CorrelationLimiterFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, aux_loss_scale: float, limit: float, mask: Optional[Tensor], name: str): + ctx.save_for_backward(x) + ctx.mask = mask + ctx.limit = limit + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 + x, = ctx.saved_tensors + mask = ctx.mask + aux_loss_scale = ctx.aux_loss_scale + (batch_size, seq_len, num_channels) = x.shape + + with torch.enable_grad(): + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + x = x.detach() + x.requires_grad = True + x_orig = x + + def norm(x: Tensor): + eps = 1.0e-20 + return x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt() + x = norm(x) + + if mask is not None: + mask = (~mask).to(x.dtype).unsqueeze(-1) + x = x * mask + + half_batch = batch_size // 2 + if half_batch <= 1: + # the reason we also return None if half_batch==1 is because of CR-CTC + # where they may really be duplicates + return None, None, None, None, None + + + #x = torch.cat((x, y), dim=-1) + C = x.shape[-1] # num_channels + x1, x2 = x[0::2], x[1::2] + x1 = x1.reshape(-1, C) + x2 = x2.reshape(-1, C) + + if mask is not None: + numel1 = mask[0::2].sum() + numel2 = mask[1::2].sum() + else: + numel1 = x1.shape[0] + numel2 = x2.shape[0] + + S1 = torch.matmul(x1.t(), x1) * (1. / numel1) + S2 = torch.matmul(x2.t(), x2) * (1. / numel2) + + # S1, S2: (N, N) where N = min(num_channels, max_channels) + correlation = (S1 * S2).mean() + loss = (correlation - ctx.limit).relu() + + if random.random() < 0.0001: + logging.info( + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, loss={loss}" + ) + + loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device)) + + + return x_orig.grad, None, None, None, None + + +class CorrelationLimiter(torch.nn.Module): + """ + Adds a penalty in backprop if the input feature has a covariance matrix that is + too different from the identity matrix. limit=1/num_channels is the + smallest limit you can provide but the limit should be much larger than + this, like 1/sqrt(num_channels). + + Assumes input is (batch, seq, channel) + """ + def __init__(self, limit: float = 0.03): + super().__init__() + self.name = None + self.limit = limit + + + def forward(self, x: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: + # x should be: (batch, seq, channel) + # returns a scalar tensor that should be included in the loss function with: + # z = with_loss(z, ret, None) + # where z is any quantity that will be used in calculating the main loss. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return torch.tensor(0.0, device=x.device) + else: + return CorrelationLimiterFunction.apply(x, + aux_loss_scale, + float(self.limit), + mask, + self.name) + + + + +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + return metric + + + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + ctx.dtype = y.dtype + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ctx.dtype, device=ans_grad.device), + None, + ) + + +def with_loss(x, y, name=None): + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return x + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x,) = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + + + + + +def torch_compile(fn, *args, **kwargs): + if hasattr(torch, 'compile'): + fn = torch.compile(fn, *args, **kwargs, dynamic=True, options={"shape_padding": True, "force_shape_pad": True}) + return fn + +def swashl(x: Tensor) -> Tensor: + zero = torch.zeros_like(x) + return 0.25 * logaddexp(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 + +def swashr(x: Tensor) -> Tensor: + zero = torch.zeros_like(x) + return 0.25 * logaddexp(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 + + +def swashl_and_deriv(x: Tensor): + x_offset = 4. * x - 4. + denom = 1. + x_offset.exp() + inv_denom = 1. / denom # note: 1 / infinity = 0. + deriv = 0.92 - inv_denom; + log_denom = denom.log() + log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) + y = 0.25 * log_denom - 0.08 * x - 0.00875 + return y, deriv + +def swashr_and_deriv(x: Tensor): + x_offset = 4. * x - 1. + denom = 1. + x_offset.exp() + inv_denom = 1. / denom # note: 1 / infinity = 0. + deriv = 0.92 - inv_denom; + log_denom = denom.log() + log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) + y = 0.25 * log_denom - 0.08 * x - 0.07831542175 + return y, deriv + + +class SwashL(torch.nn.Module): + def __init__(self): + super().__init__() + self.func = torch_compile(swashl) + def forward(self, x: Tensor) -> Tensor: + """Return Swash-L activation, which is the same as SwooshL but with a factor of 4 + on the input and 0.25 on the output..""" + return self.func(x) + +class SwashR(torch.nn.Module): + def __init__(self): + super().__init__() + self.func = torch_compile(swashr) + def forward(self, x: Tensor) -> Tensor: + """Return Swash-R activation, which is the same as SwooshL but with a factor of 4 + on the input and 0.25 on the output..""" + return self.func(x) + + + +class ActivationAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + forward_func: Any, + backward_func: Any, + ): + ctx.save_for_backward(x, weight, bias) + + ctx.backward_func = backward_func + + x = forward_func(x) + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias) = saved + + y, func_deriv = ctx.backward_func(x) + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + return x_deriv, weight_deriv, bias_deriv, None, None + + + +class ActivationAndLinear(torch.nn.Module): + """ + This merges an activation function followed by a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwashL, this will be + equivalent to: + nn.Sequential(SwashL(), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwashL, SwashR. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwashL", + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) + + self.weight = l.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter("bias", l.bias) + + self.activation = activation + + assert activation in ["SwashL", "SwashR"] + if activation == "SwashL": + self.forward_func = torch_compile(swashl) + self.backward_func = torch_compile(swashl_and_deriv) + else: + self.forward_func = torch_compile(swashr) + self.backward_func = torch_compile(swashr_and_deriv) + + + def forward(self, x: Tensor): + if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + x = self.forward_func(x) + return torch.nn.functional.linear(x, self.weight, self.bias) + + return ActivationAndLinearFunction.apply( + x, + self.weight, + self.bias, + self.forward_func, + self.backward_func, + ) + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + + +def _test_swashl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwashL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swashr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwashR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_activation_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + if True: + for activation in ["SwashL", "SwashR"]: + m1 = nn.Sequential( + SwashL() if activation == "SwashL" else SwashR(), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + ) + with torch.no_grad(): + m2.weight[:] = m1[1].weight + if bias: + m2.bias[:] = m1[1].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + print("grad1 = ", m1[1].weight.grad) + print("grad2 = ", m2.weight.grad) + + assert torch.allclose(m1[1].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[1].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwashL() implementation has a noisy gradient due to 1-byte + # storage of it. + assert isclose(x1.grad, x2.grad) + + +def _test_orthogonal_linear(): + m = OrthogonalLinear(128, 128) + m(torch.randn(30, 2, 128)) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() + _test_swashr_deriv() + _test_swashl_deriv() + _test_activation_and_linear() + _test_orthogonal_linear() diff --git a/egs/librispeech/ASR/zapformer/scaling_converter.py b/egs/librispeech/ASR/zapformer/scaling_converter.py deleted file mode 120000 index bc7c7b5e37..0000000000 --- a/egs/librispeech/ASR/zapformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/scaling_converter.py b/egs/librispeech/ASR/zapformer/scaling_converter.py new file mode 100644 index 0000000000..e4ee960838 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/scaling_converter.py @@ -0,0 +1,99 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file replaces various modules in a model. +Specifically, Whiten is replaced with an identity operator. +""" + +import copy +from typing import List + +import torch +import torch.nn as nn +from scaling import ( + SwashL, + SwashLOnnx, + SwashR, + SwashROnnx, +) +from zapformer import RelPosScores + + +# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa +# get_submodule was added to nn.Module at v1.9.0 +def get_submodule(model, target): + if target == "": + return model + atoms: List[str] = target.split(".") + mod: torch.nn.Module = model + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) + mod = getattr(mod, item) + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not " "an nn.Module") + return mod + + +def convert_scaled_to_non_scaled( + model: nn.Module, + inplace: bool = False, + is_pnnx: bool = False, + is_onnx: bool = False, +): + """ + Args: + model: + The model to be converted. + inplace: + If True, the input model is modified inplace. + If False, the input model is copied and we modify the copied version. + is_pnnx: + True if we are going to export the model for PNNX. + is_onnx: + True if we are going to export the model for ONNX. + Return: + Return a model without scaled layers. + """ + if not inplace: + model = copy.deepcopy(model) + + d = {} + for name, m in model.named_modules(): + if isinstance(m, (Dropout3, ScaleGrad, Whiten)): + d[name] = nn.Identity() + elif is_onnx and isinstance(m, SwashR): + d[name] = SwashROnnx() + elif is_onnx and isinstance(m, SwashL): + d[name] = SwashLOnnx() + elif is_onnx and isinstance(m, RelPosScores): + # We want to recreate the positional encoding vector when + # the input changes, so we have to use torch.jit.script() + # to replace torch.jit.trace() + d[name] = torch.jit.script(m) + + for k, v in d.items(): + if "." in k: + parent, child = k.rsplit(".", maxsplit=1) + setattr(get_submodule(model, parent), child, v) + else: + setattr(model, k, v) + + return model diff --git a/egs/librispeech/ASR/zapformer/streaming_beam_search.py b/egs/librispeech/ASR/zapformer/streaming_beam_search.py deleted file mode 120000 index 97e6e733f2..0000000000 --- a/egs/librispeech/ASR/zapformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/streaming_beam_search.py b/egs/librispeech/ASR/zapformer/streaming_beam_search.py new file mode 100644 index 0000000000..3c8565b330 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/streaming_beam_search.py @@ -0,0 +1,295 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List + +import k2 +import torch +import torch.nn as nn +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from decode_stream import DecodeStream + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + blank_penalty: float = 0.0, +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. + """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + num_active_paths: int = 4, + blank_penalty: float = 0.0, +) -> None: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + num_active_paths: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + streams: List[DecodeStream], + beam: float, + max_states: int, + max_contexts: int, + blank_penalty: float = 0.0, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first generated by Fsa-based beam search, then we get the + recognition by applying shortest path on the lattice. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + streams: + A list of stream objects. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + """ + assert encoder_out.ndim == 3 + B, T, C = encoder_out.shape + assert B == len(streams) + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyp_tokens[i] diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py index a04ed04adf..400f7804ce 100755 --- a/egs/librispeech/ASR/zapformer/streaming_decode.py +++ b/egs/librispeech/ASR/zapformer/streaming_decode.py @@ -19,13 +19,13 @@ """ Usage: -./zipformer/streaming_decode.py \ +./zapformer/streaming_decode.py \ --epoch 28 \ --avg 15 \ --causal 1 \ --chunk-size 32 \ --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./zapformer/exp \ --decoding-method greedy_search \ --num-decode-streams 2000 """ @@ -126,7 +126,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="zapformer/exp", help="The experiment dir", ) @@ -247,14 +247,14 @@ def get_init_states( def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances + """Stack list of zapformer states that correspond to separate utterances into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. + zapformer when those utterances are formed into a batch. Args: state_list: Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, + of the zapformer model for a single utterance. For element-n, state_list[n] is a list of cached tensors of all encoder layers. For layer-i, state_list[n][i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len). @@ -313,7 +313,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances + """Unstack the zapformer state corresponding to a batch of utterances into a list of states, where the i-th entry is the state from the i-th utterance in the batch. @@ -322,7 +322,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: Returns: state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. + of the zapformer model for a single utterance. """ assert (len(batch_states) - 2) % 5 == 0, len(batch_states) tot_num_layers = (len(batch_states) - 2) // 5 diff --git a/egs/librispeech/ASR/zapformer/subsampling.py b/egs/librispeech/ASR/zapformer/subsampling.py deleted file mode 120000 index d178adc2e5..0000000000 --- a/egs/librispeech/ASR/zapformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/subsampling.py b/egs/librispeech/ASR/zapformer/subsampling.py new file mode 100644 index 0000000000..3ec098bc20 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/subsampling.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import warnings +from typing import Tuple, Optional + +import torch +from zipformer_modules import ( + ScaledLinear, + SwashL, + SwashR, +) +from torch import Tensor, nn + + +class AddNoise(nn.Module): + # assume Conv2d-style input: (N, C, H, W) + def __init__(self, rel_noise_scale: float): + super().__init__() + self.rel_noise_scale = rel_noise_scale + + def forward(self, x: Tensor) -> Tensor: + if not self.training: + return x + eps = 3.0e-08 + noise_scale = ((x ** 2).mean(dim=(1,2,3), keepdim=True) + eps).sqrt() * self.rel_noise_scale + return x + noise_scale * torch.randn_like(x) + + +class ConvNeXt(nn.Module): + """ + Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf + """ + + def __init__( + self, + channels: int, + hidden_ratio: int = 3, + kernel_size: Tuple[int, int] = (7, 7), + causal: bool = False, + ): + super().__init__() + assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1 + self.causal = causal + hidden_channels = channels * hidden_ratio + + if not causal: + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + else: + padding = (0, kernel_size[1] // 2) + self.left_pad = kernel_size[0] - 1 + + self.depthwise_conv = nn.Conv2d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=padding, + ) + + self.pointwise_conv1 = nn.Conv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1, + ) + + self.activation = SwashL() + + self.pointwise_conv2 = nn.Conv2d( + in_channels=hidden_channels, + out_channels=channels, + kernel_size=1, + ) + + def forward( + self, x: Tensor, + ) -> Tensor: + """ + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + + The returned value has the same shape as x. + """ + bypass = x + + if self.causal: + x = nn.functional.pad(x, (0, 0, self.left_pad, 0)) + x = self.depthwise_conv(x) + assert x.shape == bypass.shape, (x.shape, bypass.shape) + + x = self.pointwise_conv1(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + x = bypass + x + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + cache: (batch_size, num_channels, left_pad, num_freqs) + + Returns: + - The returned value has the same shape as x. + - Updated cache. + """ + bypass = x + + # Pad left side with cache, and update cache + assert cache.size(2) == self.left_pad + x = torch.cat([cache, x], dim=2) + cache = x[:, :, -self.left_pad :, :] + + x = self.depthwise_conv(x) + assert x.shape == bypass.shape, (x.shape, bypass.shape) + + x = self.pointwise_conv1(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + x = bypass + x + + return x, cache + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 16, + layer2_channels: int = 64, + layer3_channels: int = 128, + causal: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-3)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite + """ + assert in_channels >= 7 + self.in_channels = in_channels + super().__init__() + # The AddNoise module is there to prevent the gradients + # w.r.t. the weight or bias of the first Conv2d module in self.conv from + # getting too large. The justification in my mind for why this might work + # is that the first Conv2d module increases the dimension of the input quite a bit, + # so its output lives in a linear subspace; and there may in principle be quite large gradients + # in directions not in this subspace, without affecting the model quality. + # so by adding a little noise we force the model to "ignore" directions not in this subspace, + # as much as possible, which will tend to avoid very large gradients. The reason the + # large gradients are a problem is because of float16 training with GradScaler, the infinities will + # be detected and will make it scale the grads by a smaller amount.. + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + AddNoise(rel_noise_scale=5.0e-03), # this AddNoise + SwashR(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + SwashR(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + SwashR(), + ) + + # just one convnext layer + self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7), causal=causal) + + # (in_channels-3)//4 + self.out_width = (((in_channels - 1) // 2) - 1) // 2 + self.layer3_channels = layer3_channels + + # scale it up a bit, else the output is quite small. + self.out = ScaledLinear(self.out_width * layer3_channels, out_channels) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + + Returns: + - a tensor of shape (N, (T-7)//2, odim) + - output lengths, of shape (batch_size,) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + x = self.convnext(x) + + # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4) + b, c, t, f = x.size() + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, (T-7)//2, out_width * layer3_channels)) + + x = self.out(x) + # Now x is of shape (N, (T-7)//2, odim) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + x_lens = (x_lens - 7) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = (x_lens - 7) // 2 + + key_padding_mask = torch.arange(0, x.shape[1], device=x.device) >= x_lens.unsqueeze(-1) + + assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) + + return 0.15 * x, x_lens + + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + cache: Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + cache: + The cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) + + Returns: + - a tensor of shape (N, (T-7)//2, odim) + - output lengths, of shape (batch_size,) + - updated cache + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + + # T' = (T-7)//2 + x = self.conv(x) + + x, cache = self.convnext.streaming_forward(x, cache=cache) + + # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, T', out_width * layer3_channels)) + + x = self.out(x) + # Now x is of shape (N, T', odim) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + x_lens = (x_lens - 7) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = (x_lens - 7) // 2 + + assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) + + return 0.15 * x, x_lens, cache + + @torch.jit.export + def get_init_cache( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + """Get initial states for Conv2dSubsampling module. + It is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + """ + left_pad = self.convnext.left_pad + freq = self.out_width + channels = self.layer3_channels + cache = torch.zeros(batch_size, channels, left_pad, freq, device=device) + + return cache + + +def _test_conv2d_subsampling_streaming(): + logging.info("Testing Conv2dSubsampling streaming equivalence...") + + batch_size = 2 + idim = 80 + odim = 256 + + model = Conv2dSubsampling( + in_channels=idim, + out_channels=odim, + causal=True + ) + + model.eval() + + out_chunk_size = 32 + in_chunk_size = out_chunk_size * 2 + 7 + in_shift = out_chunk_size * 2 + + num_chunks = 10 + + seq_len = num_chunks * in_shift + 7 + + x_full = torch.randn(batch_size, seq_len, idim) + x_lens_full = torch.full((batch_size,), seq_len, dtype=torch.int64) + + with torch.no_grad(): + out_full, out_lens_full = model(x_full, x_lens_full) + + cache = model.get_init_cache(batch_size=batch_size) + + out_chunks = [] + out_offset = 0 + + for i in range(num_chunks): + start = i * in_shift + end = start + in_chunk_size + x_chunk = x_full[:, start:end, :] + x_lens_chunk = torch.full((batch_size,), in_chunk_size, dtype=torch.int64) + + out_chunk, out_lens_chunk, cache = model.streaming_forward( + x_chunk, x_lens_chunk, cache + ) + out_chunks.append(out_chunk) + + out_chunk_len = out_chunk.shape[1] + expected_out = out_full[:, out_offset : out_offset + out_chunk_len, :] + + diff_chunk = torch.max(torch.abs(expected_out - out_chunk)) + logging.info(f"Chunk {i+1} | Input: {x_chunk.shape} -> Output: {out_chunk.shape} | Max diff: {diff_chunk}") + + assert torch.allclose(expected_out, out_chunk, atol=1e-4), f"Chunk {i+1} mismatch! max diff: {diff_chunk}" + out_offset += out_chunk_len + + out_stream_cat = torch.cat(out_chunks, dim=1) + diff_total = torch.max(torch.abs(out_full - out_stream_cat)) + logging.info(f"Total Max Diff between full forward and streaming: {diff_total}") + assert torch.allclose(out_full, out_stream_cat, atol=1e-4), "Total outputs do not match!" + + logging.info("Passed") + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_conv2d_subsampling_streaming() diff --git a/egs/librispeech/ASR/zapformer/test_subsampling.py b/egs/librispeech/ASR/zapformer/test_subsampling.py deleted file mode 120000 index 2925ea3c51..0000000000 --- a/egs/librispeech/ASR/zapformer/test_subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/test_subsampling.py b/egs/librispeech/ASR/zapformer/test_subsampling.py new file mode 100755 index 0000000000..b502d5a773 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/test_subsampling.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 + +import torch +from subsampling import Conv2dSubsampling + +# TODO: fix, this does not work right tnow +def test_conv2d_subsampling(): + layer1_channels = 8 + layer2_channels = 32 + layer3_channels = 128 + + out_channels = 192 + encoder_embed = Conv2dSubsampling( + in_channels=80, + out_channels=out_channels, + layer1_channels=layer1_channels, + layer2_channels=layer2_channels, + layer3_channels=layer3_channels, + ) + N = 2 + T = 200 + num_features = 80 + x = torch.rand(N, T, num_features) + x_copy = x.clone() + + x = x.unsqueeze(1) # (N, 1, T, num_features) + + x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1) + assert x.shape == (N, layer1_channels, T - 2, num_features) + # (2, 8, 198, 80) + + x = encoder_embed.conv[1](x) # scale grad + x = encoder_embed.conv[2](x) # balancer + x = encoder_embed.conv[3](x) # swooshR + + x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2 + assert x.shape == ( + N, + layer2_channels, + ((T - 2) - 3) // 2 + 1, + (num_features - 3) // 2 + 1, + ) + # (2, 32, 98, 39) + + x = encoder_embed.conv[5](x) # balancer + x = encoder_embed.conv[6](x) # swooshR + + # conv2d: + # in 32, out 128, kernel 3, stride (1, 2) + x = encoder_embed.conv[7](x) + assert x.shape == ( + N, + layer3_channels, + (((T - 2) - 3) // 2 + 1) - 2, + (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + ) + # (2, 128, 96, 19) + + x = encoder_embed.conv[8](x) # balancer + x = encoder_embed.conv[9](x) # swooshR + + # (((T - 2) - 3) // 2 + 1) - 2 + # = (T - 2) - 3) // 2 + 1 - 2 + # = ((T - 2) - 3) // 2 - 1 + # = (T - 2 - 3) // 2 - 1 + # = (T - 5) // 2 - 1 + # = (T - 7) // 2 + assert x.shape[2] == (x_copy.shape[1] - 7) // 2 + + # (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1, + # = ((num_features - 3) // 2 - 2) // 2 + 1, + # = (num_features - 3 - 4) // 2 // 2 + 1, + # = (num_features - 7) // 2 // 2 + 1, + # = (num_features - 7) // 4 + 1, + # = (num_features - 3) // 4 + assert x.shape[3] == (x_copy.shape[2] - 3) // 4 + + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # Input shape to convnext is + # + # (N, layer3_channels, (T-7)//2, (num_features - 3)//4) + + # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels + # kernel_size 7, padding 3 + x = encoder_embed.convnext.depthwise_conv(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1 + x = encoder_embed.convnext.pointwise_conv1(x) + assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4) + + x = encoder_embed.convnext.hidden_balancer(x) # balancer + x = encoder_embed.convnext.activation(x) # swooshL + + # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1 + x = encoder_embed.convnext.pointwise_conv2(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # bypass and layer drop, omitted here. + x = encoder_embed.convnext.out_balancer(x) + + # Note: the input and output shape of ConvNeXt are the same + + x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1) + assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4)) + + x = encoder_embed.out(x) + assert x.shape == (N, (T - 7) // 2, out_channels) + + x = encoder_embed.out_whiten(x) + x = encoder_embed.out_norm(x) + # final layer is dropout + + # test streaming forward + + subsampling_factor = 2 + cached_left_padding = encoder_embed.get_init_states(batch_size=N) + depthwise_conv_kernel_size = 7 + pad_size = (depthwise_conv_kernel_size - 1) // 2 + + assert cached_left_padding.shape == ( + N, + layer3_channels, + pad_size, + (num_features - 3) // 4, + ) + + chunk_size = 16 + right_padding = pad_size * subsampling_factor + T = chunk_size * subsampling_factor + 7 + right_padding + x = torch.rand(N, T, num_features) + x_lens = torch.tensor([T] * N) + y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward( + x, x_lens, cached_left_padding + ) + + assert y.shape == (N, chunk_size, out_channels), y.shape + assert next_cached_left_padding.shape == cached_left_padding.shape + + assert y.shape[1] == y_lens[0] == y_lens[1] + + +def main(): + test_conv2d_subsampling() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 3ac102c240..b91b52fdf9 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -24,22 +24,22 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" # For non-streaming model training: -./zipformer/train.py \ +./zapformer/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir zipformer/exp \ + --exp-dir zapformer/exp \ --full-libri 1 \ --max-duration 1000 # For streaming model training: -./zipformer/train.py \ +./zapformer/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir zipformer/exp \ + --exp-dir zapformer/exp \ --causal 1 \ --full-libri 1 \ --max-duration 1000 @@ -95,7 +95,7 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 +from zapformer import Zapformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -107,7 +107,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module +from alternating_spec_augment import AlternatingSpecAugment # using this, not lhotse's version of nn.Module from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -183,7 +183,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--num-encoder-layers", type=str, default="6,8,14,8", - help="Number of zipformer encoder layers per stack, comma separated.", + help="Number of zapformer encoder layers per stack, comma separated.", ) parser.add_argument( @@ -204,7 +204,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--embed-multiple", type=int, default=6, - help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", + help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zapformer stacks and zapformer output dim.", ) parser.add_argument( @@ -218,7 +218,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--num-heads", type=str, default="4", - help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", + help="Number of attention heads in the zapformer encoder layers, per stack: a single int or comma-separated list.", ) parser.add_argument( @@ -418,7 +418,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="zapformer/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -644,7 +644,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 10000, - # parameters for zipformer + # parameters for zapformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. # parameters for attention-decoder @@ -680,7 +680,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( + encoder = Zapformer( input_dim=lookup(params, "embed_dim"), output_downsampling_factor=2, downsampling_factor=lookup(params, "downsampling_factor"), @@ -942,7 +942,7 @@ def compute_loss( params: Parameters for training. See :func:`get_params`. model: - The model for training. It is an instance of Zipformer in our case. + The model for training. It is an instance of Zapformer in our case. batch: A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` for the content in it. @@ -1483,7 +1483,7 @@ def remove_short_and_long_utt(c: Cut): # where T is the number of feature frames after subsampling # and S is the number of tokens in the utterance - # In ./zipformer.py, the conv module uses the following expression + # In ./zapformer.py, the conv module uses the following expression # for subsampling T = ((c.num_frames - 7) // 2 + 1) // 2 tokens = sp.encode(c.supervisions[0].text, out_type=str) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py new file mode 100644 index 0000000000..0539c6e990 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -0,0 +1,2078 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from zapformer_modules import ( + ActivationAndLinear, + CausalSequeneNorm, + CorrelationLimiter, + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + OrthogonalLinear, + RmsNorm, + SequenceNorm, + OrthogonalLinear, + ScaledLinear, # just an initializer for Linear + SwashR, + ScaleLimiter, +) +from zapformer_utils import ( + limit_param_value, + penalize_abs_values_gt, + softmax, + with_loss, +) + + +from torch import Tensor, nn + +from icefall.utils import make_pad_mask + + +class Zapformer(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + pos_head_dim (int or Tuple[int]): dimension of position-embedding in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules + conv_params (int or Tuple[int])): Kernel size of convolution module + + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. + """ + def __init__( + self, + input_dim: int, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + query_head_dim: Union[int, Tuple[int]] = 64, + value_head_dim: Union[int, Tuple[int]] = 12, + pos_head_dim: Union[int, Tuple[int]] = 4, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_multiple: Union[int, Tuple[int]] = 4, + conv_params: Union[int, Tuple[int]] = 31, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zapformer, self).__init__() + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + self.pos_head_dim = pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_multiple = _to_tuple(feedforward_multiple) + self.conv_params = conv_params = _to_tuple(conv_params) + + self.causal = causal + self.chunk_size = (chunk_size,) if isinstance(chunk_size, int) else chunk_size + self.left_context_frames = (left_context_frames,) if isinstance(left_context_frames, int) else left_context_frames + + # each one will be ZapformerEncoder or OrthogonalDownsample or OrthogonalUpsample + encoders = [] + + num_encoders = len(downsampling_factor) + + # caution: some changes we made for this break the streaming, later we'll try to fix this. + encoders_downsampling_factors = [ ] + + # make it so large the limit is never reached. + max_proj_dim = max(downsampling_factor) * max(encoder_dim) + + + for i in range(num_encoders): + encoder_layer = ZapformerEncoderLayer( + embed_dim=encoder_dim[i], + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + value_head_dim=value_head_dim[i], + pos_head_dim=pos_head_dim[i], + feedforward_multiple=feedforward_multiple[i], + conv_params=conv_params[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = ZapformerEncoder( + encoder_layer, + num_encoder_layers[i], + dim=downsampling_factor[i]*input_dim, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.out_norm = RmsNorm() + + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + Returns: + Return (embeddings_lengths), where: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + chunk_size, left_context_chunks = self.get_chunk_info() + orig_seq_len = x.shape[0] + + pad = (-orig_seq_len) % max(self.downsampling_factor) + # pad sequence length to be multiple of max(self.downsampling_factor) + x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), + dim=0) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) + + num_stacks = len(self.downsampling_factor) + + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = downsample_by(x, ds) + T = x.shape[0] + x = module( + x, + chunk_size=chunk_size // ds if chunk_size > 0 else -1, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=(None + if attn_mask is None + else attn_mask[::ds, ::ds] + ), + aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) + ) + x = upsample_by(x, ds) + + od = self.output_downsampling_factor + x = downsample_by(x, od) + x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + + if od > 1: + x_lens = (x_lens + od - 1) // od + + x = self.out_norm(x) + + return x, x_lens + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.conv_params[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) # TODO: could test remove this + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + caches: List[Tensor], + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + caches: list of cached tensors of all encoder layers. For layer-i, + caches[i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated caches: an updated list of cache tensors. + """ + orig_seq_len = x.shape[0] + pad = (-orig_seq_len) % max(self.downsampling_factor) + # pad sequence length to be multiple of max(self.downsampling_factor) + x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), dim=0) + + if src_key_padding_mask is not None: + left_context_frames = src_key_padding_mask.shape[1] - orig_seq_len + assert left_context_frames == self.left_context_frames[0] + if pad > 0: + src_key_padding_mask = torch.cat( + (src_key_padding_mask[:, :left_context_frames], + pad_mask(src_key_padding_mask[:, left_context_frames:], x.shape[0])), + dim=1, + ) + + new_caches = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + + x = downsample_by(x, ds) + + # Slice out the specific caches for the current module + module_caches = caches[layer_offset * 5 : (layer_offset + num_layers) * 5] + + x, new_module_caches = module.streaming_forward( + src=x, + caches=module_caches, + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + ) + + layer_offset += num_layers + new_caches.extend(new_module_caches) + + x = upsample_by(x, ds) + + # Output downsampling and normalization + od = self.output_downsampling_factor + x = downsample_by(x, od) + + x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding + + if od > 1: + x_lens = (x_lens + od - 1) // od + + x = self.out_norm(x) + + return x, x_lens, new_caches + + @torch.jit.export + def get_init_caches( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial caches. + + A list of cached tensors of all encoder layers. For layer-i, caches[i*5:(i+1)*5] + is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len). + """ + caches = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + conv_left_pad = self.conv_params[i] - 1 + + for layer_idx, enc_layer in enumerate(module.layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim, device=device) + cached_value = torch.zeros(downsample_left, batch_size, value_dim, device=device) + cached_conv = torch.zeros(batch_size, embed_dim, conv_left_pad, device=device) + cached_norm_stats, cached_norm_len = enc_layer.norm.get_init_cache(batch_size) + cached_norm_stats = cached_norm_stats.to(device) + cached_norm_len = cached_norm_len.to(device) + + caches.extend([ + cached_key, + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + ]) + + return caches + + +def pad_mask(mask: Optional[Tensor], seq_len: int): + # mask: (batch_size, old_seq_len) + # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). + if mask is None: + return None + (batch_size, old_seq_len) = mask.shape + pad = seq_len - old_seq_len + if pad == 0: + return mask + else: + return torch.cat((mask, torch.ones(batch_size, pad, device=mask.device, dtype=torch.bool)), + dim=1) + + +def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len // downsampling_factor, batch_size, num_channels * downsampling_factor) + if downsampling_factor == 1: + return x + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len // downsampling_factor, downsampling_factor, batch_size, num_channels) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len // downsampling_factor, batch_size, downsampling_factor * num_channels) + return x + +def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: + # x: (seq_len, batch_size, num_channels) + # Returns: (seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + if upsampling_factor == 1: + return x + (seq_len, batch_size, num_channels) = x.shape + x = x.reshape(seq_len, batch_size, upsampling_factor, num_channels // upsampling_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) + return x + + +def get_dct_matrix(N): + """ + Generates an orthonormal DCT-II matrix for a given size N. + Args: + N (int): The size of the square matrix. + Returns: + torch.Tensor: The N x N orthonormal DCT-II matrix. + """ + # Create the base matrix with dimensions (N, N) + mat = torch.zeros(N, N) + # Create a tensor for the indices k (rows) and n (columns) + k = torch.arange(N).unsqueeze(1) + n = torch.arange(N).unsqueeze(0) + # Fill the matrix using the DCT-II formula + mat = math.sqrt(2 / N) * torch.cos(math.pi / (2 * N) * (2 * n + 1) * k) + # Adjust the first row (k=0) with a special normalization factor + mat[0] *= (2 ** -0.5) + return mat + + +class ZapformerEncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_multiple: determines the hidden dimension of the feedforward module + + conv_params (int): params per channel of convolution module + + Examples:: + >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + value_head_dim: int, + pos_head_dim: int, + feedforward_multiple: int, + conv_params: int, + causal: bool = False, + ) -> None: + super(ZapformerEncoderLayer, self).__init__() + self.embed_dim = embed_dim + self.name = None # will be set from training loop + + self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) + + power = 0.4 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) + + self.self_attn = MultiheadRelPosGatedSelfAttention( + embed_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + value_head_dim=value_head_dim, + pos_head_dim=pos_head_dim, + ) + + feedforward_dim = embed_dim * feedforward_multiple + self.feed_forward1 = FeedforwardModule(embed_dim, feedforward_dim) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim) + + self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal) + + self.norm = CausalSequenceNorm() if causal else SequenceNorm() + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + aux_loss_scale: + If supplied, auxiliary losses such as CosineSimilarityLoss will be + applied with this scale on the loss (note, these aux losses are + reduced via summation over frames.) + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), + 2. * aux_loss_scale, mask=src_key_padding_mask), + None) + + src_pre_ff1 = src + + src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + # may try changing src_pre_ff1 to src or vice versa. + src = src + self.self_attn(src_pre_ff1, src, attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.conv_module(3. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + + src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) + + residual_scale = 0.25 + offset = (src - src_orig) * residual_scale + + offset = self.offset_scale_limiter(offset, aux_loss_scale) + + src = src_orig + offset + + src = self.norm(src, src_key_padding_mask) + + return src + + def streaming_forward( + self, + src: Tensor, + cached_key: Tensor, + cached_value: Tensor, + cached_conv: Tensor, + cached_norm_stats: Tensor, + cached_norm_len: Tensor, + left_context_len: int, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + cached_key: cached attention key tensor, of shape (left_context_len, batch_size, key_dim) + cached_value: cached attention value tensor, of shape (left_context_len, batch_size, value_dim) + cached_conv: cached left context for the convolution module, of shape (batch_size, channels, left_pad) + cached_norm_stats: cached SequenceNorm stats, of shape (batch_size,) + cached_norm_len: cached SequenceNorm length, scalar. + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len + seq_len); + True means masked position. May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_value + - updated cached_conv + - updated cached_norm_stats + - updated cached_norm_len + """ + src_orig = src + + src_pre_ff1 = src + + chunk_mask = None if src_key_padding_mask is None else src_key_padding_mask[:, left_context_len:] + + src = src + self.feed_forward1(src, src_key_padding_mask=chunk_mask) + + # may try changing src_pre_ff1 to src or vice versa. + self_attn_out, cached_key, cached_value = self.self_attn.streaming_forward( + x_qkp=src_pre_ff1, + x_vg=src, + left_context_len=left_context_len, + cached_key=cached_key, + cached_value=cached_value, + key_padding_mask=src_key_padding_mask, + ) + src = src + self_attn_out + + src_conv, cached_conv = self.conv_module.streaming_forward( + 3.0 * src, + cache=cached_conv, + src_key_padding_mask=chunk_mask, + ) + src = src + src_conv + + src = src + self.feed_forward2(src, src_key_padding_mask=chunk_mask) + + residual_scale = 0.25 + offset = (src - src_orig) * residual_scale + + src = src_orig + offset + + src, cached_norm_stats, cached_norm_len = self.norm.streaming_forward( + src, + cached_stats_sum=cached_norm_stats, + cached_len=cached_norm_len, + ) + + return ( + src, + cached_key, + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + ) + + +class ZapformerEncoder(nn.Module): + r"""ZapformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZapformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + dim: the dimension of the input and output (layer dim may be less than this). + + Examples:: + >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, nhead=8) + >>> zapformer_encoder = ZapformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zapformer_encoder(src) + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dim: int, + ) -> None: + super().__init__() + + # self.downsample will also reverse the downsampling operation for us afterward. + self.proj = OrthogonalLinear(dim, encoder_layer.embed_dim, + lr_scale=0.66, bias=False) + + self.name = None + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.residual_scales = nn.Parameter( + torch.cat([ -1.0 * torch.ones(1), + (1. / num_layers) * torch.ones(num_layers) ], + dim=0)) + + self.input_scale = nn.Parameter(torch.tensor([1.0])) + + self.copy_bypass = Identity() + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), + but embed_dim is allowed to exceed the modules' embed_dim; we will bypass + any extra dimensions. + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + (out, out_sd), both of the same shape as src, + where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. + """ + src_orig_fulldim = src + + src = self.proj(src) # project to layer dim. + + num_layers = len(self.layers) + src_orig = src + + residual_scale = limit_param_value(self.residual_scales[0], + min=-1.0, max=-0.5) + input_scale = limit_param_value(self.input_scale, + min=0.5, max=2.0) + + src_with_bypass = residual_scale * src + src = input_scale * src + + for i, mod in enumerate(self.layers): + + src = mod( + src, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale/num_layers, + ) + residual_scale = limit_param_value(self.residual_scales[i + 1], + min=0.0 if i + 1 < num_layers else 0.5, + max=1.0) + src_with_bypass = src_with_bypass + residual_scale * src + + + offset = src_with_bypass + + src = src_orig_fulldim + self.proj(offset, transpose=True) + # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, + # because of some identities involving orthogonal matrices. + + return src + + def streaming_forward( + self, + src: Tensor, + caches: List[Tensor], + left_context_len: int, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn in streaming mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). + caches: list of cached tensors of N encoder layers. For layer-i, + caches[i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated caches + """ + src_orig_fulldim = src + + # project to layer dim. + src = self.proj(src) + + num_layers = len(self.layers) + assert len(caches) == num_layers * 5 + + residual_scale = self.residual_scales[0] + input_scale = self.input_scale + + src_with_bypass = residual_scale * src + src = input_scale * src + + new_caches = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + ) = caches[i * 5 : (i + 1) * 5] + + ( + src, + new_cached_key, + new_cached_value, + new_cached_conv, + new_cached_norm_stats, + new_cached_norm_len, + ) = mod.streaming_forward( + src, + cached_key=cached_key, + cached_value=cached_value, + cached_conv=cached_conv, + cached_norm_stats=cached_norm_stats, + cached_norm_len=cached_norm_len, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + + layer_residual_scale = self.residual_scales[i + 1] + + src_with_bypass = src_with_bypass + layer_residual_scale * src + + new_caches.extend([ + new_cached_key, + new_cached_value, + new_cached_conv, + new_cached_norm_stats, + new_cached_norm_len, + ]) + + offset = src_with_bypass + src = src_orig_fulldim + self.proj(offset, transpose=True) + + return src, new_caches + + + + +class MultiheadRelPosGatedSelfAttention(nn.Module): + r""" + Module that computes multi-head attention weights with additive relative-position + scores that are kept separate from the regular scores. The values have gating. + An RMSNorm module is used to pre-normalize the input embedding only as it is + input to the queries and keys, not the values. + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int = 4, + value_head_dim: int = 12, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.name = None # will be overwritten in training code; for diagnostics. + + self.in_norm = RmsNorm() + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.qkp_in_proj = ScaledLinear( + embed_dim, in_proj_dim, + bias=True, initial_scale=0.125 * query_head_dim**-0.25 + ) + + self.rel_pos = RelPosScores(num_heads, pos_head_dim, num_freqs=64) + + self.copy_query = Identity() + self.copy_pos_query = Identity() + + # value and gating in_proj. + self.vg_in_proj = ScaledLinear(embed_dim, 2 * num_heads * value_head_dim, + initial_scale=0.1, bias=True) + + self.copy_v = nn.Identity() # diagnostics. + self.sigmoid = nn.Sigmoid() + + # out proj for the value times gating. + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.5 + ) + + def forward( + self, + x_qkp: Tensor, + x_vg: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + r""" + Args: + x_qkp: input of shape (seq_len, batch_size, embed_dim), that is used for the queries, + keys and positions. + x_vg: input of shape (seq_len, batch_size, embed_dim), that is used for the values + and gates. May be the same as x_qk. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + query_head_dim = self.query_head_dim + num_heads = self.num_heads + x_qkp = self.in_norm(x_qkp) + x_qkp = self.qkp_in_proj(x_qkp) + + seq_len, batch_size, _ = x_qkp.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x_qkp[..., 0:query_dim] + k = x_qkp[..., query_dim : 2 * query_dim] + p = x_qkp[..., 2 * query_dim:] + + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, -1) + + #q = self.rope(q.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + #k = self.rope(k.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) # (head, batch, time1, time2) + + p = p.permute(1, 2, 0, 3) + pos_scores = self.rel_pos(p) # (batch, head, time1, time2) + attn_scores = attn_scores + pos_scores.permute(1, 0, 2, 3) + + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, + 0.1 * aux_loss_scale, + key_padding_mask, self.name) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + + v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) + v = v.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + v = self.copy_v(v) + value_head_dim = v.shape[-1] + # now v: (num_heads, batch_size, seq_len, value_head_dim) + + # todo: see whether there is benefit in overriding matmul + v = torch.matmul(attn_weights, v) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + v = ( + v.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + if self.training: + # don't let the sigmoid values get too extreme, limit to -2..2. + g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + v = v * self.sigmoid(g) + v = self.out_proj(v) + return v + + def streaming_forward( + self, + x_qkp: Tensor, + x_vg: Tensor, + left_context_len: int, + cached_key: Tensor, + cached_value: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Args: + x_qkp: input of shape (seq_len, batch_size, embed_dim), that is used for the queries, + keys and positions. + x_vg: input of shape (seq_len, batch_size, embed_dim), that is used for the values + and gates. May be the same as x_qk. + left_context_len: length of the cached left context. + cached_key: cached attention key tensor, of shape (left_context_len, batch_size, key_dim). + cached_value: cached attention value tensor, of shape (left_context_len, batch_size, value_dim). + key_padding_mask: a bool tensor of shape (batch_size, left_context_len + seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention output, of shape (seq_len, batch_size, embed_dim) + - updated cached_key, of shape (left_context_len, batch_size, key_dim) + - updated cached_value, of shape (left_context_len, batch_size, value_dim) + """ + query_head_dim = self.query_head_dim + num_heads = self.num_heads + x_qkp = self.in_norm(x_qkp) + x_qkp = self.qkp_in_proj(x_qkp) + + seq_len, batch_size, _ = x_qkp.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x_qkp[..., 0:query_dim] + k = x_qkp[..., query_dim : 2 * query_dim] + p = x_qkp[..., 2 * query_dim:] + + # append the cached key to the current key, and update the cache + assert cached_key.shape[0] == left_context_len, (cached_key.shape, left_context_len) + k = torch.cat([cached_key, k], dim=0) + kv_len = k.shape[0] + cached_key = k[kv_len - left_context_len:] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(kv_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, -1) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, query_head_dim, time2) + + attn_scores = torch.matmul(q, k) # (head, batch, time1, time2) + + p = p.permute(1, 2, 0, 3) + pos_scores = self.rel_pos(p, left_context_len) # (batch, head, time1, time2) + attn_scores = attn_scores + pos_scores.permute(1, 0, 2, 3) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, kv_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, kv_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), -1000) + + attn_weights = attn_scores.softmax(dim=-1) + + v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) + + # append the cached value to the current value, and update the cache + assert cached_value.shape[0] == left_context_len, (cached_value.shape, left_context_len) + v = torch.cat([cached_value, v], dim=0) + cached_value = v[kv_len - left_context_len:] + + v = v.reshape(kv_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + value_head_dim = v.shape[-1] + # now v: (num_heads, batch_size, kv_len, value_head_dim) + + # todo: see whether there is benefit in overriding matmul + v = torch.matmul(attn_weights, v) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + v = ( + v.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + v = v * self.sigmoid(g) + v = self.out_proj(v) + + return v, cached_key, cached_value + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class PenalizeLargeAttentionScores(torch.autograd.Function): + @staticmethod + def forward( + ctx, + attn_scores: Tensor, + limit: float, + aux_loss_scale: float, + key_padding_mask: Optional[Tensor], + name: str): + # attn_scores: (head, batch, query_time, key_time) + ctx.save_for_backward(attn_scores) + ctx.mask = key_padding_mask # has no grad + ctx.limit = limit + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return attn_scores + + @staticmethod + def backward( + ctx, + attn_scores_grad): + attn_scores, = ctx.saved_tensors + mask = ctx.mask + (num_heads, batch_size, seq_len, _) = attn_scores.shape + with torch.amp.autocast('cuda', enabled=False): + attn_scores = attn_scores.to(torch.float) + attn_scores = attn_scores.detach() + # attn_scores: (head, batch, query_time, key_time) + attn_scores.requires_grad = True + with torch.enable_grad(): + probs = attn_scores.softmax(dim=-1) + scaled_scores = attn_scores.abs() * probs + avg_scores = scaled_scores.sum(dim=-1) # (head, batch, query_time) + if mask is not None: + avg_scores = avg_scores * (~mask) # mask: (batch, time) + query_scores = (avg_scores - ctx.limit).relu() + + if random.random() < 0.0005: + query_excess = query_scores.mean(dim=(1,2)).to('cpu') + avg_scores_mean = avg_scores.mean(dim=(1,2)).to('cpu') + logging.info(f"PenalizeLargeAttentionScores: {ctx.name}, limit={ctx.limit}, avg_scores={avg_scores_mean}, query_excess={query_excess}") + # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number + # of frames which is batch_size * seq_len. normalize by dividing by num heads. + # also divide by ctx.limit so it's like penalizing a relative excess. + query_scores.backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / (num_heads * ctx.limit))) + + return attn_scores_grad + attn_scores.grad, None, None, None, None + + + + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zapformer model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int): + super(FeedforwardModule, self).__init__() + # try to get in the useful range of the activation function, i.e. not too small. + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed + # to the TransformedAdam optimizer. + self.in_proj.weight_min_rms = 0.02 + + self.out_proj = ActivationAndLinear( + feedforward_dim, + embed_dim, + activation="SwashR", + initial_scale=0.5, + bias=True, + ) + + + def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + x = self.in_proj(x) + x = self.out_proj(x) + return x + +def round_up_to_power_of_two(x): + x = x - 1 + x = x | x >> 1 + x = x | x >> 2 + x = x | x >> 4 + x = x | x >> 8 + x = x | x >> 16 + x = x + 1 + return x + + +# wolfram alpha: +# the right part of the triangular bin, from 0 to +W. +# definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega +# = -(i W x + e^(-i W x) - 1)/(W x^2) +# Re[definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega] +# = (1 - cos(W x))/(W x^2) +# Im[definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega] +# = (sin(W x) - W x)/(W x^2) + +# the left part of the triangular bin, from -W to 0. +# definite integral from omega = -W to 0 of (omega/W + 1) exp(-i x \omega) d\omega +# (i W x - e^(i W x) + 1)/(W x^2) +# +# Let the center frequency be C. +# right side: +# = e^(i C x) * -(i W x + e^(-i W x) - 1)/(W x^2) +# "alternate form including W, C and x are real": [note, this is left hand width, W_l] +# (W x sin(C x) - cos(x (C - W)) + cos(C x))/(W x^2) - (i (sin(x (C - W)) + W x cos(C x) - sin(C x)))/(W x^2) +# +# left side: +# e^(i C x) * (i W x - e^(i W x) + 1)/(W x^2) +# "alternate form including W, C and x are real": [note, this is right hand width, W_r] +# -(W x sin(C x) + cos(x (C + W)) - cos(C x))/(W x^2) + (i (-sin(x (C + W)) + W x cos(C x) + sin(C x)))/(W x^2) +# +# summing the left and right sides: +# Real part: +# +# (W_r x sin(C x) - cos(x (C - W_r)) + cos(C x))/(W_r x^2) +# -(W_l x sin(C x) + cos(x (C + W_l)) - cos(C x))/(W_l x^2) +# = (cos(C x) - cos((C - W_r)x)) / W_r x^2 +# + (cos(C x) - cos((C + W_l)x)) / W_l x^2 + +# Imaginary part: +# -(sin(x (C - W_r)) + W_r x cos(C x) - sin(C x))) / (W_r x^2) +# +(-sin(x (C + W_l)) + W_l x cos(C x) + sin(C x)) / (W_l x^2) +# = ( sin(C x) - sin((C - W_r)x) ) / (W_r x^2) +# + ( sin(C x) - sin((C + W_l)x) ) / (W_l x^2) + +def compute_angular_freq_basis_triangular(freqs: Tensor, + t: Tensor, + scale: bool) -> Tensor: + """ + This function computes a set of windowed sinusoidal functions + corresponding to the real and imaginary parts of possibly-asymmetrical + triangular angular-frequency bins in frequency space. This basis + allows you to approximate functions whose fourier spectrum is + a piecewise linear function of frequency, with the x-axis values of + the inflection points of the piecewise linear function corresponding + to the supplied "freqs". + + Args: + freqs: the frequencies of the triangular-bin centers; the left and + right parts of the widths of the triangular bins correspond to the + distances to the two adjacent bins; for the "edge" bins, the + "edge" distances are duplicated. + t: the "t" (or x) values for which we want to evaluate the basis; this + will normally be some kind of arange expression e.g. arange(100). + scale: if True, the returned basis will contain the "natural" scaling + factors that arise from the bin widths; if False, it will be + normalized so that the maximum absolute value of the real + functions (attained at t==0) is 1. + + + Returns: + Returns the real and imaginary parts of the basis functions, with + shape (t.size(), freqs.size(), 2) + """ + dtype = freqs.dtype + freqs = freqs.to(torch.double) + t = t.to(torch.double) + + t = t.unsqueeze(-1) + + + C = freqs # Center frequencies of bins. + W = freqs[1:] - freqs[:-1] # the differences between the frequencies + W_l = torch.cat((W[:1], W)) # the difference between each center freq and the freq to the left + W_r = torch.cat((W, W[-1:])) # the difference between each center freq and the freq to the right + + angles = C * t + angles_r = (C - W_r) * t + angles_l = (C + W_l) * t + t2 = t**2 + scale_factor = 0.5 * (W_r + W_l) + + re = torch.where(t == 0., scale_factor, + (angles.cos() - angles_r.cos()) / (W_r * t2) + (angles.cos() - angles_l.cos()) / (W_l * t2)) + im = torch.where(t == 0., 0.0, + (angles.sin() - angles_r.sin()) / (W_r * t2) + (angles.sin() - angles_l.sin()) / (W_l * t2)) + + + if not scale: + re = re / scale_factor + im = im / scale_factor + + return torch.stack((re, im), dim=-1).to(dtype) + + +class RelPosScores(nn.Module): + def __init__(self, + num_heads: int, + pos_head_dim: int, + num_freqs: int, + low_freq_factor: float = 0.001): + """ + Implementation of relative position scores; where conventional relative position scores + would use sinusoids, we treat each sinusoid frequency as the central frequency of a + triangular "bucket" (like the buckets in mel bins) of frequencies. What this amounts + to is that instead of a sinusoid we get something a bit like a sinusoid times a + sinc-squared function (the sinc-squared function is the fourier transform of a triangular + function). Actually it's not the sinc-squared funtion, it's a slightly more complicated + function than that because the "triangles" have uneven shapes, due to the center frequencies + of the triangles not being evenly spaced. + + Args: + num_heads: the number of heads + pos_head_dim: the dimension of the head; in a conventionally structured model this would + be identical to the query-dim but we make the "position query" independent of + the main query and with a smaller dimension. + num_freqs: the number of frequencies of the sin and cos functions + low_freq_factor: this is approximately the amount by which the lowest frequency will be + less than the highest frequency, the highest frequency being the Nyquist (pi). + The frequencies are close to a geometric series at higher frequency but linear + at low frequency. + """ + super().__init__() + self.weight = nn.Parameter(0.04 * torch.randn(num_heads, pos_dim, 2 * num_freqs)) + with torch.no_grad(): + # initialize the weight in a low-pass way. I think this is not so critical + # actually, it may not matter. + for _ in range(10): + self.weight[:] = (2 ** -0.5) * (self.weight + self.weight.roll(1, dims=2)) + + log_freqs = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs) + freqs = math.pi * (log_freqs.exp() - low_freq_factor) # these range from 0 to pi. + freqs[0] = 0.0 # in case of roundoff (it should be 0, mathematically) + self.register_buffer('freqs', freqs, persistent=False) + + def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: + """ + Compute and return unnormalized log scores for relative position. + Args: + p: these are the position-queries, of shape (batch_size, num_heads, seq_len, pos_dim) + (they are obtained via projection, just like the queries). + left_context_len: length of left context, must be 0 for non-streaming forward and > 0 for streaming forward. + Returns: + scores: (batch_size, num_heads, dest_seq_len, src_seq_len), where dest_seq_len relates to the + query and src_seq_len to the key. + In non-streaming forward, dest_seq_len and src_seq_len are numerically equal to seq_len; + in streaming forward, dest_seq_len is seq_len and src_seq_len is seq_len + left_context_len. + """ + (batch_size, num_heads, seq_len, pos_dim) = p.shape + + freqs = self.freqs # base freqs + t = torch.arange(-(seq_len + left_context_len - 1), seq_len, device=p.device) + basis = compute_angular_freq_basis_triangular(freqs, t, scale=False) + # basis: (2 * seq_len + left_context_len - 1, num_freqs, 2) + basis = basis.permute(0, 2, 1) + # permute it because of how we did the low-pass initialization of weight, we want + # the cos and sin parts to each be continuous ranges, not interleaved. + basis = basis.reshape(basis.shape[0], -1) # (2 * seq_len + left_context_len - 1, 2 * num_freqs) + + x = torch.matmul(self.weight, basis.t()) + assert x.shape == (num_heads, pos_dim, 2 * seq_len + left_context_len - 1) + + # with seq_len2 = 2 * seq_len + left_context_len - 1, + # (batch, head, seq_len, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, seq_len, seq_len2) + pos_weights = torch.matmul(p, x) + + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. This is all copied from our old conformer/zapformer code. + if torch.jit.is_tracing(): + seq_len2 = pos_weights.shape[-1] + rows = torch.arange(start=seq_len - 1, end=-1, step=-1) + cols = torch.arange(left_context_len + seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, seq_len2) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, seq_len, left_context_len + seq_len) + else: + pos_weights = pos_weights.as_strided( + (batch_size, num_heads, seq_len, left_context_len + seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + return pos_weights + + +def round_up_to_power_of_two(x): + x = x - 1 + x = x | x >> 1 + x = x | x >> 2 + x = x | x >> 4 + x = x | x >> 8 + x = x | x >> 16 + x = x + 1 + return x + + + +# FftConv was formerly used as the depthwise_conv module in ConvolutionModule. +# CAUTION: this is not used right now, we use BasisConv plus WeightedMean in +# parallel for the depthwise convolution in the ConvModule. Using FftConv is +# actually just as good in WER terms and is also more efficient, versus (BasisConv +# plus WeightedMean); FftConv itself, should be about twice faster because it +# operates on a twice-shorter length than BasisConv since BasisConv pads for +# exactness. For the overall training the speed difference is about 10%. +# The reason we use BasisConv is because it is properly invariant to +# how we pad different-length sequences into a batch, while FftConv cannot +# be made to give exactly the same results independent of the batch size, because +# it treats the signal as repeating in time which depends on the FFT size which +# depends on the longest sequence in the batch. Unfortunately, we don't know +# exactly how the model is going to be used and we don't want it to become +# deal-breaker that batching very-different-length sequences together in inference +# time could significantly affect the model results. For image tasks, +# FftConv may still be useful (after suitable adaptation), because +# you wouldn't normally try to inference different size images in a batch. +class FftConv(nn.Module): + def __init__(self, + num_channels: int, + params_per_channel: int, + bias: bool = True): + super().__init__() + self.weight = nn.Parameter(0.1 * torch.randn(num_channels, params_per_channel)) + # one factor of 2 is for (sin, cos); the other is to double the num representable freqs + self.weight_proj = nn.Linear(params_per_channel, 4 * params_per_channel) + if bias: + self.bias = nn.Parameter(0.01 * torch.randn(num_channels)) + + + def forward(self, + x: Tensor) -> Tensor: + (seq_len, batch_size, num_channels) = x.shape + + # select a power of two that's >= seq_len // 8 and round up seq_len + # to a multiple of that power. This means that rounded_seq_len + # will be of the form (2**n) * k where k <= 8, so it won't contain + # many factors other than two; this will make the FFT more efficient + # without adding an excessive amount of padding. + power_of_two = max(1, round_up_to_power_of_two(seq_len // 8)) + rounded_seq_len = power_of_two * ((seq_len + power_of_two - 1) // power_of_two) + + + with torch.amp.autocast('cuda', enabled=False): + # do it in float32 because non power of two seq_len is not supported in half precision. + x = torch.fft.rfft(x.to(torch.float32), dim=0, n=rounded_seq_len) + # x: (num_freqs, batch_size, num_channels) + N = x.shape[0] # num freqs + weight = 4. * self.weight + weight = self.weight_proj(weight).reshape(num_channels, 2, -1) # (num_channels, 2, 2 * params_per_channel) + # this scale of 10 times is because of interactions with commonly + # used optimizers, it's to help this module learn faster than it + # otherwise would. + weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True) + weight = torch.view_as_complex(weight.permute(2, 0, 1).contiguous()) + # weight: (N, num_channels) + weight = weight.unsqueeze(1) # (N, 1, num_channels) + x = x * weight + x = torch.fft.irfft(x, n=rounded_seq_len, dim=0) + + x = x[:seq_len] + + try: + x = x + self.bias + except AttributeError: + pass + + return x + + +# convolution where we convolve with a combination of basis functions, the basis functions +# being based on linear interpolation in Fourier space-- in effect, each pair of basis functions +# corresponds to the real and imaginary coefficients for one triangular bin in Fourier space; +# in the time domain the triangular bin becomes a sinc^2 function and the frequency offset +# is just a complex exponential of which the real and imaginary coefficients give us sines and +# cosines. +def get_basis_funcs(seq_len: int, + num_freqs: int, + **kwargs +): + """ + seq_len: the sequence length to which the basis functions are truncated; this is expected to + be even + num_freqs: the number of frequencies; the number of basis functions will be 2 * num_freqs, + and note that the first pair of basis functions are special, because they are the + (zero-freq; nyquist-freq) ones. + kwargs: can be used for device + + Returns: + basis functions of shape: (2 * num_freqs, seq_len) + """ + assert seq_len % 2 == 0 + t = torch.cat((torch.arange(seq_len // 2, **kwargs), + torch.arange(-seq_len // 2, 0, **kwargs)), dim=0) # e.g. tensor([ 0, 1, 2, 3, -4, -3, -2, -1]) + # the second half of the "t" values are interpreted as the "negative half" of the time range-- + # the time range representing t values from -seq_len // 2 to seq_len // 2 - 1. + # The way we use this will be to convolve it with a signal of size seq_len // 2 that + # has been padded with zeroes of length seq_len // 2, and we want the result to be as if we padded with the basis + # functions from -infinity to infinity. + + + scaled_t = t * math.pi / num_freqs + + # "freqs" are the t values multiplied by the basis frequencies + t_freqs = scaled_t * torch.arange(num_freqs + 1, **kwargs).unsqueeze(-1) + # t_freqs: (num_freqs + 1, seq_len) + + # it's a sinc-squared envelope, as the frequency domain envelope is a + # triangular, not a rectangular, function. the factor of 0.5 comes + # from the math + sinc_arg = 0.5 * scaled_t + envelope = torch.where(sinc_arg != 0.0, sinc_arg.sin() / sinc_arg, torch.ones_like(sinc_arg)) ** 2 + + + cos, sin = t_freqs.cos() * envelope, t_freqs.sin() * envelope + #plt.plot(envelope) + + # the factor of 0.5 is because the other freqs would get "counted twice" due + # to having two symmetric versions, the freqs at zero and the nyquist only have + # one copy. This ensures that if we give a coeff of all ones on all the + # cos terms, we get (a scaled version of) the delta function. + sin[0] = 0.5 * cos[-1] + cos[0] = 0.5 * cos[0] + # the sin coefficient of freq 0 and nyquist gives us nothing, so we use the cos + # at the nyquist in this position. + cos = cos[:num_freqs] + sin = sin[:num_freqs] + #scale = num_freqs ** -0.5 # scale to make the funcs have a value around 1. + #cos = cos * scale + #sin = sin * scale + + basis = torch.cat((cos, sin), dim=0) + # basis: (2 * num_freqs, seq_len) + + #for i in range(num_freqs + 1): + # plt.plot(cos[i]) + # plt.plot(sin[i]) + # plt.show() + return basis + + +def fourier_conv(x: Tensor, y: Tensor): + # fourier based convolution of x and y, returns + # something with the same sequence length as the shorter of + # the two. + # x, y: (seq_len, [1 or batch_size], num_channels) + T = max(x.shape[0], y.shape[0]) + T_out = min(x.shape[0], y.shape[0]) + + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + y = y.to(torch.float) + X = torch.fft.rfft(x, dim=0, n=T) + Y = torch.fft.rfft(y, dim=0, n=T) + return torch.fft.irfft(X * Y, dim=0, n=T)[:T_out] + +# fourier-based convolution, mem-efficient wrapper for fourier_conv. +class FourierConv(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return fourier_conv(x, y) + + @staticmethod + def backward(ctx, ans_grad): + # we could probably do a bit better than this by doing it manually + x, y = ctx.saved_tensors + with torch.enable_grad(): + x = x.detach() + y = y.detach() + x.requires_grad = True + y.requires_grad = True + ans = fourier_conv(x, y) + ans.backward(gradient=ans_grad) + return x.grad, y.grad + + +class WeightedMean(nn.Module): + # this is like the core part of squeeze-and-excite: it computes a mean over time, + # and then multiplies it by a learnable channel-specific weight. + # we add this to a more conventional convolution; we found this was helpful because + # normal convolution cannot do averaging-over-time since it does not know the + # sequence length. + def __init__(self, + num_channels: int, + causal: bool = False): + super().__init__() + self.causal = causal + self.weights = nn.Parameter(0.1 * torch.randn(num_channels)) + + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + """ + Compute weighted mean. + x: (time, batch, channel) + src_key_padding_mask: (batch, time), True for masked positions + + Returned shape: (time, batch, channel) if causal else (batch, channel) + """ + T = x.shape[0] + if self.causal: + num_frames = torch.arange(1, T + 1, device=x.device) + x_cumsum = torch.cumsum(x, dim=0) + return x_cumsum * num_frames[:, None, None] * self.weights + + + # assume x already masked, if mask is in use. + if src_key_padding_mask is not None: + num_frames = src_key_padding_mask.logical_not().to(torch.float).sum(dim=1) + num_frames = num_frames.unsqueeze(-1).to(torch.float) + + # num_frames: (batch_size, 1) + return x.mean(dim=0) * (T / num_frames) * self.weights + else: + return x.mean(dim=0) * self.weights + +class BasisConv(nn.Module): + def __init__(self, + num_channels: int, + num_freqs: int, + params_per_channel: int): + super().__init__() + self.weight_proj = nn.Linear(params_per_channel, 2 * num_freqs) + + self.weight = nn.Parameter(0.05 * torch.randn(num_channels, + params_per_channel)) + + + def forward(self, + x: Tensor) -> Tensor: + (seq_len, batch_size, num_channels) = x.shape + + + # round seq_len to a multiple of "round" to help ensure the FFT dimension + # has plenty of powers of two; this will tend to make it more efficient. + round = min(16, round_up_to_power_of_two(seq_len)) + seq_len_rounded = round * ((seq_len + round - 1) // round) + + # to ensure the answer is the same regardless of the amount of padding, we + # pad the sequence to at least twice its initial length for purposes of + # the FFT-based convolution. Because we will view the basis functions + # as going from t=-seq_len_rounded to t=seq_len_rounded - 1, this will + # ensure that we never see "wrap-around" effects. + T = 2 * seq_len_rounded + + num_freqs = self.weight_proj.weight.shape[0] // 2 + basis_funcs = get_basis_funcs(T, num_freqs, device=x.device) + # basis_funcs: (2 * num_freqs, T) + + scale = num_freqs ** -0.5 + + weight = scale * self.weight_proj(self.weight) + # weight: (num_channels, 2 * num_freqs) + channel_funcs = torch.matmul(weight, basis_funcs) + # channel_funcs: (num_channels, T) + + + # channel_funcs: (num_channels, T) + channel_funcs = channel_funcs.t().unsqueeze(1) + # channel_funcs: (T, 1, num_channels) + + return FourierConv.apply(channel_funcs, x) + + + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zapformer model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 3 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + self.activation1 = Identity() # for diagnostics + + self.sigmoid1 = nn.Sigmoid() + + self.sigmoid2 = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + + if not causal: + self.depthwise_conv = BasisConv(bottleneck_dim, + num_freqs=kernel_size*2, + params_per_channel=kernel_size) + else: + self.depthwise_conv = nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=0, # will pad manually, on one side. + bias=True, + ) + self.left_pad = kernel_size - 1 + + self.depthwise_conv.lr_scale = 0.66 + # add average-of-all-frames to the "convolution."; it has extra power vs the convolution + # because the num frames differs between utterances. + self.weighted_mean = WeightedMean(bottleneck_dim, + causal=causal) + + self.out_proj = ActivationAndLinear( + bottleneck_dim, + channels, + activation="SwashR", + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + aux_loss_scale: float = 0.0, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + """ + + x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) + + x, s, y = x.chunk(3, dim=2) + s = self.sigmoid1(s) + y = self.sigmoid2(y) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + + # x: (time, batch, channels) + # Caution: this module is not completely + # invariant to the number of frames each sequence is padded with, since + # the FFT-based convolution treats the signal as repeating. + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) + + + wm = self.weighted_mean(x, src_key_padding_mask) + if self.causal: + # Not support exporting a model for simulated streaming decoding + assert not torch.jit.is_scripting() and not torch.jit.is_tracing() + x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) + x_shape = x.shape + x = torch.nn.functional.pad(x, (self.left_pad, 0)) + x = self.depthwise_conv(x) + assert x.shape == x_shape, (x.shape, x_shape) + x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) + else: + x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) + x = x + wm # Add in the weighted-mean to the convolution; this adds extra power + # because the utterances differ in length. + + x = x * y + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv, of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) + + x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) + + x, s, y = x.chunk(3, dim=2) + s = self.sigmoid1(s) + y = self.sigmoid2(y) + x = x * s + + x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x_shape = x.shape + assert cache.shape[-1] == self.left_pad, (cache.shape[-1], self.left_pad) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[..., -self.left_pad:] + + x = self.depthwise_conv(x) + assert x.shape == x_shape, (x.shape, x_shape) + + x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) + + x = x * y + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +def _test_zapformer_main(causal: bool = False): + seq_len = 20 + # Just make sure the forward pass runs. + + input_dim = 50 + + c = Zapformer( + input_dim=input_dim, + encoder_dim=(64, 96), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + + batch_size = 6 + seq_len = 21 + # Just make sure the forward pass runs. + f, lengths = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + ) + f.sum().backward() + c.eval() + x_ = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + ) + x_ # to remove flake8 warnings + + logging.info(f"Zapformer forward test passed, causal={causal}") + + +def _test_zapformer_streaming(): + input_dim = 50 + batch_size = 2 + chunk_size = 32 + num_chunks = 10 + tail_chunk_size = 8 + seq_len = chunk_size * num_chunks + tail_chunk_size + left_context_frames = 128 + + model = Zapformer( + input_dim=input_dim, + encoder_dim=(64, 96, 128, 96), + num_heads=(4, 4, 4, 4), + conv_params=(31, 31, 15, 31), # it may be better to make these even if not in causal mode. + downsampling_factor=(1, 2, 4, 2), + causal=True, + chunk_size=(chunk_size,), + left_context_frames=(left_context_frames,), + ) + + model.eval() + + x_full = torch.randn(seq_len, batch_size, input_dim) + x_lens_full = torch.full((batch_size,), seq_len, dtype=torch.int64) + + with torch.no_grad(): + out_full, out_lens_full = model(x_full, x_lens_full) + + caches = model.get_init_caches(batch_size=batch_size) + + out_chunks = [] + out_offset = 0 + processed_lens = torch.full((batch_size,), 0, dtype=torch.int64) + + for i in range(num_chunks): + start = i * chunk_size + end = start + chunk_size + x_chunk = x_full[start:end] + x_lens = torch.full((batch_size,), chunk_size, dtype=torch.int64) + + src_key_padding_mask = make_pad_mask(x_lens) + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_frames).expand(batch_size, left_context_frames) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + out_chunk, out_lens, caches = model.streaming_forward( + x=x_chunk, + x_lens=x_lens, + caches=caches, + src_key_padding_mask=src_key_padding_mask, + ) + out_chunks.append(out_chunk) + + out_chunk_len = out_chunk.shape[0] + expected_out = out_full[out_offset : out_offset + out_chunk_len] + diff_chunk = torch.max(torch.abs(expected_out - out_chunk)) + logging.info(f"Chunk {i+1} | Input: {x_chunk.shape} -> Output: {out_chunk.shape} | Max diff: {diff_chunk}") + assert torch.allclose(expected_out, out_chunk, atol=2e-5), f"Chunk {i+1} outputs do not match! Max diff: {diff_chunk}" + + out_offset += out_chunk_len + + x_tail = x_full[num_chunks * chunk_size:] + x_lens_tail = torch.full((batch_size,), tail_chunk_size, dtype=torch.int64) + src_key_padding_mask = make_pad_mask(x_lens_tail) + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_frames).expand(batch_size, left_context_frames) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + processed_lens = processed_lens + x_lens_tail + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + out_tail, out_lens_tail, caches = model.streaming_forward( + x=x_tail, + x_lens=x_lens_tail, + caches=caches, + src_key_padding_mask=src_key_padding_mask, + ) + out_chunks.append(out_tail) + + out_tail_len = out_tail.shape[0] + expected_out_tail = out_full[out_offset : out_offset + out_tail_len] + diff_tail = torch.max(torch.abs(expected_out_tail - out_tail)) + logging.info(f"Tail Chunk | Input: {x_tail.shape} -> Output: {out_tail.shape} | Max diff: {diff_tail}") + assert torch.allclose(expected_out_tail, out_tail, atol=2e-5), f"Tail Chunk outputs do not match! Max diff: {diff_tail}" + out_offset += out_tail_len + + out_stream_cat = torch.cat(out_chunks, dim=0) + + diff = torch.max(torch.abs(out_full - out_stream_cat)) + logging.info(f"Max abs diff between full forward and streaming forward: {diff}") + + assert torch.allclose(out_full, out_stream_cat, atol=2e-5), f"Outputs do not match! Max diff: {diff}" + + logging.info("Passed") + + + +def _test_basis_conv(): + num_channels = 11 + f = BasisConv(num_channels=num_channels, + num_freqs=4, + params_per_channel=2) + + seq_len = 100 + subseq_len = 10 # will help visualize the effect + batch_size = 2 + x = torch.cat((torch.randn(subseq_len, batch_size, num_channels), + torch.zeros(seq_len - subseq_len, batch_size, num_channels)), + dim=0) + + y = f(x) + + #plt.plot(x[:, 0, 0].detach()) + #plt.plot(y[:, 0, 0].detach()) + #plt.show() + + + def rms(a): + return (a**2).mean().item() + print(f"rms(x)={rms(x)}, rms(y)={rms(y)}") + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_basis_conv() + _test_zapformer_main(False) + _test_zapformer_main(True) + _test_zapformer_streaming() diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py new file mode 100644 index 0000000000..77a3b12640 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -0,0 +1,999 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import math +import copy +import random +from typing import Optional, Tuple, Union, Any + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + + + + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + + + +# all arg tensors except x are scalars. +def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): + stats = (x ** 2).mean(dim=2, keepdim=True) + T = x.shape[0] # time + if mask is None: + stats = stats.sum(dim=0) + lengths = T + else: + mask = (~mask).to(torch.float).t().unsqueeze(-1) + stats = stats * mask + stats = stats.sum(dim=0) + lengths = mask.sum(dim=0) + + scales = (lengths / stats).sqrt() + assert scales.shape == (x.shape[1], 1) + return x * ((scale * scales) + offset) + +# all arg tensors except x are scalars. +def _causal_sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor): + stats = (x ** 2).mean(dim=2, keepdim=True) + + # no need for mask in causal mode. + # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so + # make absolutely sure using abs(). + ballast_frames = 100.0 * ballast_frames.abs() + ballast = ballast_frames * (ballast_rms ** 2) + T = x.shape[0] # time + + stats = stats.cumsum(dim=0) + ballast + lengths = ballast_frames + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] + + scales = (lengths / stats).sqrt() + assert scales.shape == (T, x.shape[1], 1) + return x * ((scale * scales) + offset) + + +# all arg tensors are scalars +def _causal_sequence_norm_streaming( + x: Tensor, + offset: Tensor, + scale: Tensor, + cached_stats_sum: Tensor, + cached_len: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """Streaming inference forward for _sequence_norm. We assume that ballast_frames and ballast_rms + are already included in cached_stats_sum and cached_len. + + Args: + x: (seq_len, batch_size, channels) + offset: scalar + scale: scalar + cached_stats_sum: (batch_size,) + cached_len: (batch_size,) + + Returns: + - normalized x, (seq_len, batch_size, channels) + - updated cached_stats_sum, (batch_size,) + - updated cached_len, (batch_size,) + """ + stats = (x ** 2).mean(dim=2, keepdim=True) # (seq_len, batch_size, 1) + + T = x.shape[0] # time + + stats = stats.cumsum(dim=0) + cached_stats_sum.unsqueeze(-1) + lengths = cached_len[:, None] + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] + + # update cached_stats_sum and cached_len for the next chunk + cached_stats_sum = stats[-1].squeeze(-1) # (batch_size,) + cached_len = cached_len + T + + scales = (lengths / stats).sqrt() # (T, batch_size, 1) + assert scales.shape == (T, x.shape[1], 1) + return x * ((scale * scales) + offset), cached_stats_sum, cached_len + + +class CausalSequenceNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + offset: Tensor, + scale: Tensor, + ballast_rms: Tensor, + ballast_frames: Tensor, + ) -> Tensor: + ctx.save_for_backward(x, offset, scale, ballast_rms, ballast_frames) + + return _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, offset, scale, ballast_rms, ballast_frames = ctx.saved_tensors + + + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float32).detach().requires_grad_() + offset = offset.to(torch.float32).detach().requires_grad_() + scale = scale.to(torch.float32).detach().requires_grad_() + ballast_rms = ballast_rms.to(torch.float32).detach().requires_grad_() + ballast_frames = ballast_frames.to(torch.float32).detach().requires_grad_() + + with torch.enable_grad(): + ans = _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode: scalars will tend to have larger grads than non-scalars, + # this code is to reduce the probabilities that any infinities could crash the + # training (it may still happen if the world-size is so large that these + # infinities get added together though). + return x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad) + +class SequenceNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + offset: Tensor, + scale: Tensor, + mask: Optional[Tensor], + ) -> Tensor: + ctx.save_for_backward(x, offset, scale) + ctx.mask = mask + + return _sequence_norm(x, offset, scale, mask) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, offset, scale = ctx.saved_tensors + + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float32).detach().requires_grad_() + offset = offset.to(torch.float32).detach().requires_grad_() + scale = scale.to(torch.float32).detach().requires_grad_() + + with torch.enable_grad(): + ans = _sequence_norm(x, offset, scale, ctx.mask) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode: scalars will tend to have larger grads than non-scalars, + # this code is to reduce the probabilities that any infinities could crash the + # training (it may still happen if the world-size is so large that these + # infinities get added together though). + return x if x is None else x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(offset.grad), c(scale.grad), None + + +class CausalSequenceNorm(torch.nn.Module): + """ + This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence + up to the current point as well as the channels, with some padding of the stats with "default values" + determined by ballast_frames, ballast_rms for robustness near the beginning of the sequence. + + There is also a learnable scalar scale, multiplicatively applied to the output, and a learnable + "offset" value that acts multiplicatively on the input without taking into account the rms values. + """ + def __init__( + self, + ) -> None: + super().__init__() + self.scale = nn.Parameter(torch.tensor(0.5)) + self.offset = nn.Parameter(torch.tensor(0.0001)) + + # ballast_mean: assumed rms value of ballast frames used to pad stats + self.ballast_rms = nn.Parameter(torch.tensor(0.1)) + # ballast_frames: number of ballast frames, in hundreds (will be multiplied by 100) + self.ballast_frames = nn.Parameter(torch.tensor(0.05)) # number of ballast frames, will be multiplied by 100 + self.name = None + + def forward(self, x: Tensor, _mask: Optional[Tensor] = None) -> Tensor: + # x: (seq, batch, channel) + # The mask is ignored, it is allowed only for consistency of interface with SequenceNorm. + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _causal_sequence_norm(x, self.offset, self.scale, self.ballast_rms, self.ballast_frames) + + scale = limit_param_value( + self.scale, min=0.05, max=2.0, training=self.training) + + offset = limit_param_value( + self.offset, min=0.0, max=10.0, training=self.training) + + ballast_rms = limit_param_value( + self.ballast_rms, min=0.0, max=10.0, training=self.training) + + ballast_frames = limit_param_value( + self.ballast_frames, min=0.0, max=5.0, training=self.training) # max of 5.0 would be 500 frames + + ans = CausalSequenceNormFunction.apply( + x, offset, scale, ballast_rms, ballast_frames, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}, ballast_rms={self.ballast_rms.item()}, ballast_frames*100={100*self.ballast_frames.item()}") + + return ans + + @torch.jit.export + def get_init_cache(self, batch_size: int): + """Get initial cache for streaming inference. We first include the ballast stats in the initial cache. + """ + # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so + # make absolutely sure using abs(). + ballast_frames = 100.0 * self.ballast_frames.abs() + ballast = ballast_frames * (self.ballast_rms ** 2) + + cached_stats_sum = ballast.unsqueeze(0).repeat(batch_size) # (batch_size,) + cached_len = ballast_frames.unsqueeze(0).repeat(batch_size) # (batch_size,) + + return cached_stats_sum, cached_len + + def streaming_forward( + self, + x: Tensor, + cached_stats_sum: Tensor, + cached_len: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + + x, cached_stats_sum, cached_len = _causal_sequence_norm_streaming( + x, self.offset, self.scale, cached_stats_sum, cached_len) + return x, cached_stats_sum, cached_len + + +class SequenceNorm(torch.nn.Module): + """ + This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence + as well as the channels; and a padding mask is used for irregular length sequences (actually, + the mask is applied multiplicatively as well.) + + There is also a learnable scalar scale and a learnable "offset" value. + """ + def __init__( + self, + ) -> None: + super().__init__() + self.scale = nn.Parameter(torch.tensor(0.5)) + self.offset = nn.Parameter(torch.tensor(0.0001)) + self.name = None + + def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: + # x: (seq, batch, channel) + # mask: bool, (batch_size, seq_len) + # Note: mask is ignored in causal mode. + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _sequence_norm(x, self.offset, self.scale, mask) + + scale = limit_param_value( + self.scale, min=0.05, max=2.0, training=self.training) + + offset = limit_param_value( + self.offset, min=0.0, max=10.0, training=self.training) + + ans = SequenceNormFunction.apply( + x, offset, scale, mask, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}") + + return ans + + + +# assume layout: (time, batch, channel) +def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) + scales = scale / x_sq.sqrt() + return x * scales + + +class RmsNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + eps: Tensor, + scale: Tensor, + ) -> Tensor: + ctx.save_for_backward(x, eps, scale) + return _rms_norm(x, eps, scale) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, eps, scale = ctx.saved_tensors + + with torch.amp.autocast('cuda', enabled=False): + x, eps, scale = x.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, eps, scale = x.detach(), eps.detach(), scale.detach() + + x.requires_grad = True + eps.requires_grad = True + scale.requires_grad = True + + with torch.enable_grad(): + ans = _rms_norm(x, eps, scale) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode. + return x.clamp_(min=-30000.0, max=30000.0) + + return x.grad, c(eps.grad), c(scale.grad) + + +class RmsNorm(torch.nn.Module): + """ + This is RMSNorm with a trainable scale and trainable epsilon. + """ + def __init__( + self, + ) -> None: + super(RmsNorm, self).__init__() + self.scale = nn.Parameter(torch.tensor(0.2)) # output scale + self.eps = nn.Parameter(torch.tensor(0.1)) + self.name = None + + + def forward(self, x: Tensor) -> Tensor: + # Assumes layout is (time, batch, channel) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _rms_norm(x, self.eps, self.scale) + + scale = limit_param_value( + self.scale, min=0.05, max=1.0, training=self.training) + + eps = limit_param_value( + self.eps, min=0.0, max=10.0, training=self.training) + + ans = RmsNormFunction.apply( + x, eps, scale, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, eps={eps.item()}, scale={scale.item()}") + + return ans + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.01 * initial_scale, 0.01 * initial_scale) + return ans + + +class OrthogonalPenaltyFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, weight: Tensor, penalty_scale: float, name: str): + ctx.save_for_backward(weight) + ctx.name = name + ctx.penalty_scale = penalty_scale + return weight + + @staticmethod + @custom_bwd + def backward(ctx, weight_grad): + weight, = ctx.saved_tensors + + if weight.requires_grad and ctx.penalty_scale != 0.0: + penalty_scale = ctx.penalty_scale * weight_grad.abs().mean() + + with torch.enable_grad(): + weight = weight.detach() + weight.requires_grad = True + + # Compute symmetric matrix-product prod with the smallest + # dimension possible given the shape of w. This is not just for + # efficiency; if we computed it the wrong way round, the product + # would have deficient rank and could never be the identity. + if (weight.shape[0] > weight.shape[1]): + prod = torch.matmul(weight.t(), weight) + else: + prod = torch.matmul(weight, weight.t()) + + # we'll try to enforce that for any i, prod[i] is any constant times the identity. + + # in the loss-function: + # orthogonality_loss = ((prod - I) ** 2).sum(), + + # note, prod_diag shares memory with prod, this will matter later on. + (r, c) = prod.shape + (r_stride, c_stride) = prod.stride() + + def diag_inplace(z): + return torch.as_strided(z, size=(r,), stride=(r_stride+c_stride,)) + + diag_inplace(prod)[:] -= 1. + + # that loss that we want to backprop would be 0.5 * (prod ** + # 2).sum() * penalty_scale. we can backprop this without doing + # any reductions as follows: + prod.backward(gradient=prod * penalty_scale) + + + do_print = random.random() < 0.002 + if do_print: + # we print a normalized version of the loss, by dividing by the + # number of rows. + loss = (prod ** 2).mean() + logging.info(f"OrthogonalLinear: name={ctx.name}, loss={loss.detach().cpu()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") + + + # add the extra gradient term from the orthogonality loss. + weight_grad = weight_grad + weight.grad + return weight_grad, None, None + +class OrthogonalLinear(nn.Linear): + """ + Like nn.Linear but can enforce that the weight matrix is orthogonal; in the non-square + case this is interpreted as either M^T M == I or M M^T == I, whichever would give a smaller + dimension. + (If M is square, these definitions are equivalent and is equivalent to the normal + definition of orthogonal). + + Args: + in_channels: number of input channels + out_channels: number of output channels + lr_scale: we will scale the weight by this value before applying the orthogonal + constraint and using it; with most optimizers + this will have the effect of slowing down the learning by this factor because + the parameter value will be larger. + bias: if True, include a bias term. + penalty_scale: a scale on the penalty on non-orthogonality (this will + be multiplied by the average-absolute-value of the + backpropagated gradient). + """ + # if in_groups or out_groups are set to >1, the orthogonal constraint + # will be set per group. both of them cannot be >1. + def __init__(self, + in_channels: int, + out_channels: int, + lr_scale: float = 1.0, + bias: bool = True, + penalty_scale: float = 20.0, + ): + super().__init__(in_channels, out_channels, bias=bias) + self.name = None + self.penalty_scale = copy.deepcopy(penalty_scale) + self.lr_scale = lr_scale + + with torch.no_grad(): + self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * (1. / lr_scale) + if self.bias is not None: + torch.nn.init.uniform_(self.bias, -0.01, 0.01) + + + def forward(self, x: Tensor, transpose: bool = False): + # you can only use transpose=True if you used bias=False in initialization + weight = self.weight + lr_scale = self.lr_scale + if lr_scale != 1.0: + weight = weight * lr_scale + if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): + weight = OrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) + + if transpose: + weight = weight.t() + return torch.nn.functional.linear(x, weight, self.bias) + + +class ScaleLimiterFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, max_rms: float, aux_loss_scale: float, name: str): + ctx.save_for_backward(x) + ctx.max_rms = max_rms + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + x, = ctx.saved_tensors + with torch.enable_grad(): + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + x = x.detach() + x.requires_grad = True + rms = (x ** 2).mean(dim=-1).sqrt() + numel = rms.numel() + + excess = (rms / ctx.max_rms - 1.).relu().mean() + + if random.random() < 0.002: + logging.info( + f"ScaleLimiter: name={ctx.name}, max_rms={ctx.max_rms}, " + f"rms={rms.mean().item()}, excess={excess.item()}, " + f"loss_scale={ctx.aux_loss_scale}" + ) + excess.backward(gradient=torch.full_like(excess, ctx.aux_loss_scale * numel)) + return x_grad + x.grad, None, None, None + + +class ScaleLimiter(torch.nn.Module): + """ + Adds a penalty in backprop if the norm of any activation vector is less than min_rms + or more than max_rms. + + Assumes channel dim is -1 and the input shape has >1 dimension. + """ + def __init__(self, max_rms: float): + super().__init__() + self.name = None + self.max_rms = max_rms + + + def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return _no_op(x) + else: + return ScaleLimiterFunction.apply(x, float(self.max_rms), + aux_loss_scale, self.name) + + +class CorrelationLimiterFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, aux_loss_scale: float, limit: float, mask: Optional[Tensor], name: str): + ctx.save_for_backward(x) + ctx.mask = mask + ctx.limit = limit + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 + x, = ctx.saved_tensors + mask = ctx.mask + aux_loss_scale = ctx.aux_loss_scale + (batch_size, seq_len, num_channels) = x.shape + + with torch.enable_grad(): + with torch.amp.autocast('cuda', enabled=False): + x = x.to(torch.float) + x = x.detach() + x.requires_grad = True + x_orig = x + + def norm(x: Tensor): + eps = 1.0e-20 + return x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt() + x = norm(x) + + if mask is not None: + mask = (~mask).to(x.dtype).unsqueeze(-1) + x = x * mask + + half_batch = batch_size // 2 + if half_batch <= 1: + # the reason we also return None if half_batch==1 is because of CR-CTC + # where they may really be duplicates + return None, None, None, None, None + + + #x = torch.cat((x, y), dim=-1) + C = x.shape[-1] # num_channels + x1, x2 = x[0::2], x[1::2] + x1 = x1.reshape(-1, C) + x2 = x2.reshape(-1, C) + + if mask is not None: + numel1 = mask[0::2].sum() + numel2 = mask[1::2].sum() + else: + numel1 = x1.shape[0] + numel2 = x2.shape[0] + + S1 = torch.matmul(x1.t(), x1) * (1. / numel1) + S2 = torch.matmul(x2.t(), x2) * (1. / numel2) + + # S1, S2: (N, N) where N = min(num_channels, max_channels) + correlation = (S1 * S2).mean() + loss = (correlation - ctx.limit).relu() + + if random.random() < 0.0001: + logging.info( + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, loss={loss}" + ) + + loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device)) + + + return x_orig.grad, None, None, None, None + + +class CorrelationLimiter(torch.nn.Module): + """ + Adds a penalty in backprop if the input feature has a covariance matrix that is + too different from the identity matrix. limit=1/num_channels is the + smallest limit you can provide but the limit should be much larger than + this, like 1/sqrt(num_channels). + + Assumes input is (batch, seq, channel) + """ + def __init__(self, limit: float = 0.03): + super().__init__() + self.name = None + self.limit = limit + + + def forward(self, x: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: + # x should be: (batch, seq, channel) + # returns a scalar tensor that should be included in the loss function with: + # z = with_loss(z, ret, None) + # where z is any quantity that will be used in calculating the main loss. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return torch.tensor(0.0, device=x.device) + else: + return CorrelationLimiterFunction.apply(x, + aux_loss_scale, + float(self.limit), + mask, + self.name) + + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + + + +def torch_compile(fn, *args, **kwargs): + if hasattr(torch, 'compile'): + fn = torch.compile(fn, *args, **kwargs, dynamic=True, options={"shape_padding": True, "force_shape_pad": True}) + return fn + +def swashl(x: Tensor) -> Tensor: + zero = torch.zeros_like(x) + return 0.25 * logaddexp(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 + +def swashr(x: Tensor) -> Tensor: + zero = torch.zeros_like(x) + return 0.25 * logaddexp(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 + + +def swashl_and_deriv(x: Tensor): + x_offset = 4. * x - 4. + denom = 1. + x_offset.exp() + inv_denom = 1. / denom # note: 1 / infinity = 0. + deriv = 0.92 - inv_denom; + log_denom = denom.log() + log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) + y = 0.25 * log_denom - 0.08 * x - 0.00875 + return y, deriv + +def swashr_and_deriv(x: Tensor): + x_offset = 4. * x - 1. + denom = 1. + x_offset.exp() + inv_denom = 1. / denom # note: 1 / infinity = 0. + deriv = 0.92 - inv_denom; + log_denom = denom.log() + log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) + y = 0.25 * log_denom - 0.08 * x - 0.07831542175 + return y, deriv + + +class SwashL(torch.nn.Module): + def __init__(self): + super().__init__() + self.func = torch_compile(swashl) + def forward(self, x: Tensor) -> Tensor: + """Return Swash-L activation, which is the same as SwooshL but with a factor of 4 + on the input and 0.25 on the output..""" + return self.func(x) + +class SwashR(torch.nn.Module): + def __init__(self): + super().__init__() + self.func = torch_compile(swashr) + def forward(self, x: Tensor) -> Tensor: + """Return Swash-R activation, which is the same as SwooshL but with a factor of 4 + on the input and 0.25 on the output..""" + return self.func(x) + + + +class ActivationAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + forward_func: Any, + backward_func: Any, + ): + ctx.save_for_backward(x, weight, bias) + + ctx.backward_func = backward_func + + x = forward_func(x) + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias) = saved + + y, func_deriv = ctx.backward_func(x) + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + return x_deriv, weight_deriv, bias_deriv, None, None + + + +class ActivationAndLinear(torch.nn.Module): + """ + This merges an activation function followed by a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwashL, this will be + equivalent to: + nn.Sequential(SwashL(), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwashL, SwashR. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwashL", + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) + + self.weight = l.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter("bias", l.bias) + + self.activation = activation + + assert activation in ["SwashL", "SwashR"] + if activation == "SwashL": + self.forward_func = torch_compile(swashl) + self.backward_func = torch_compile(swashl_and_deriv) + else: + self.forward_func = torch_compile(swashr) + self.backward_func = torch_compile(swashr_and_deriv) + + + def forward(self, x: Tensor): + if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + x = self.forward_func(x) + return torch.nn.functional.linear(x, self.weight, self.bias) + + return ActivationAndLinearFunction.apply( + x, + self.weight, + self.bias, + self.forward_func, + self.backward_func, + ) + + + +def _test_swashl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwashL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swashr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwashR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_activation_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + if True: + for activation in ["SwashL", "SwashR"]: + m1 = nn.Sequential( + SwashL() if activation == "SwashL" else SwashR(), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + ) + with torch.no_grad(): + m2.weight[:] = m1[1].weight + if bias: + m2.bias[:] = m1[1].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + print("grad1 = ", m1[1].weight.grad) + print("grad2 = ", m2.weight.grad) + + assert torch.allclose(m1[1].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[1].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwashL() implementation has a noisy gradient due to 1-byte + # storage of it. + assert isclose(x1.grad, x2.grad) + + +def _test_orthogonal_linear(): + m = OrthogonalLinear(128, 128) + m(torch.randn(30, 2, 128)) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_swashr_deriv() + _test_swashl_deriv() + _test_activation_and_linear() + _test_orthogonal_linear() diff --git a/egs/librispeech/ASR/zapformer/zapformer_utils.py b/egs/librispeech/ASR/zapformer/zapformer_utils.py new file mode 100644 index 0000000000..4b8b1dc8cd --- /dev/null +++ b/egs/librispeech/ASR/zapformer/zapformer_utils.py @@ -0,0 +1,181 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import math +import copy +import random +from typing import Optional, Tuple, Union, Any + +import k2 +import torch +import torch.nn as nn +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.get_autocast_gpu_dtype()) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.amp.autocast('cuda', enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim=dim) + + return SoftmaxFunction.apply(x, dim) + + +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + ctx.dtype = y.dtype + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ctx.dtype, device=ans_grad.device), + None, + ) + + +def with_loss(x, y, name=None): + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y, name) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x,) = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() diff --git a/egs/librispeech/ASR/zapformer/zipformer.py b/egs/librispeech/ASR/zapformer/zipformer.py deleted file mode 120000 index a064749a48..0000000000 --- a/egs/librispeech/ASR/zapformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/zipformer.py \ No newline at end of file From 47f059efa0557d07a667e70f9811f17e73c25a55 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Mar 2026 16:54:57 +0800 Subject: [PATCH 0989/1191] Various bug fixes --- .../ASR/zapformer/attention_decoder.py | 1 + egs/librispeech/ASR/zapformer/decoder.py | 113 ++++++++++++++++++ egs/librispeech/ASR/zapformer/subsampling.py | 2 +- egs/librispeech/ASR/zapformer/train.py | 5 +- egs/librispeech/ASR/zapformer/zapformer.py | 12 +- .../ASR/zapformer/zapformer_modules.py | 2 +- 6 files changed, 124 insertions(+), 11 deletions(-) create mode 120000 egs/librispeech/ASR/zapformer/attention_decoder.py create mode 100644 egs/librispeech/ASR/zapformer/decoder.py diff --git a/egs/librispeech/ASR/zapformer/attention_decoder.py b/egs/librispeech/ASR/zapformer/attention_decoder.py new file mode 120000 index 0000000000..830180a0cd --- /dev/null +++ b/egs/librispeech/ASR/zapformer/attention_decoder.py @@ -0,0 +1 @@ +../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/decoder.py b/egs/librispeech/ASR/zapformer/decoder.py new file mode 100644 index 0000000000..fc6aec95e6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decoder.py @@ -0,0 +1,113 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + ) + with torch.no_grad(): + # and we will scale by 10 in forward. this is because with an optimizer that has weight decay, + # it's best if all the parameters have fairly similar dynamic range. + self.embedding.weight[:] *= 0.1 + + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim // 4, # group size == 4 + bias=False, + ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) * 20.0 + + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + + return embedding_out diff --git a/egs/librispeech/ASR/zapformer/subsampling.py b/egs/librispeech/ASR/zapformer/subsampling.py index 3ec098bc20..45c97d468f 100644 --- a/egs/librispeech/ASR/zapformer/subsampling.py +++ b/egs/librispeech/ASR/zapformer/subsampling.py @@ -21,7 +21,7 @@ from typing import Tuple, Optional import torch -from zipformer_modules import ( +from zapformer_modules import ( ScaledLinear, SwashL, SwashR, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b91b52fdf9..79e6b3306a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -61,7 +61,6 @@ from typing import Any, Dict, Optional, Tuple, Union import k2 -import optim import sentencepiece as spm import torch import torch.multiprocessing as mp @@ -916,12 +915,12 @@ def augmentation( feature_lens=feature_lens, ) - # note: ExpAugment() does *somewhat* assume that x consists of two copies of + # note: AlternatingSpecAugment() does *somewhat* assume that x consists of two copies of # the same data, but practically speaking the only important use this is put # to is that it chooses non-overlapping frequency regions to mask. it also # chooses non-overlapping time regions to mask, but this is not so important # since the time warping (if used) was done independently on the two copies. - spec_augment = ExpAugment() + spec_augment = AlternatingSpecAugment() features = spec_augment(features) return features diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 0539c6e990..c226e73914 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -27,7 +27,7 @@ from encoder_interface import EncoderInterface from zapformer_modules import ( ActivationAndLinear, - CausalSequeneNorm, + CausalSequenceNorm, CorrelationLimiter, Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. OrthogonalLinear, @@ -1385,7 +1385,7 @@ def __init__(self, at low frequency. """ super().__init__() - self.weight = nn.Parameter(0.04 * torch.randn(num_heads, pos_dim, 2 * num_freqs)) + self.weight = nn.Parameter(0.04 * torch.randn(num_heads, pos_head_dim, 2 * num_freqs)) with torch.no_grad(): # initialize the weight in a low-pass way. I think this is not so critical # actually, it may not matter. @@ -1401,7 +1401,7 @@ def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: """ Compute and return unnormalized log scores for relative position. Args: - p: these are the position-queries, of shape (batch_size, num_heads, seq_len, pos_dim) + p: these are the position-queries, of shape (batch_size, num_heads, seq_len, pos_head_dim) (they are obtained via projection, just like the queries). left_context_len: length of left context, must be 0 for non-streaming forward and > 0 for streaming forward. Returns: @@ -1410,7 +1410,7 @@ def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: In non-streaming forward, dest_seq_len and src_seq_len are numerically equal to seq_len; in streaming forward, dest_seq_len is seq_len and src_seq_len is seq_len + left_context_len. """ - (batch_size, num_heads, seq_len, pos_dim) = p.shape + (batch_size, num_heads, seq_len, pos_head_dim) = p.shape freqs = self.freqs # base freqs t = torch.arange(-(seq_len + left_context_len - 1), seq_len, device=p.device) @@ -1422,10 +1422,10 @@ def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: basis = basis.reshape(basis.shape[0], -1) # (2 * seq_len + left_context_len - 1, 2 * num_freqs) x = torch.matmul(self.weight, basis.t()) - assert x.shape == (num_heads, pos_dim, 2 * seq_len + left_context_len - 1) + assert x.shape == (num_heads, pos_head_dim, 2 * seq_len + left_context_len - 1) # with seq_len2 = 2 * seq_len + left_context_len - 1, - # (batch, head, seq_len, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, seq_len, seq_len2) + # (batch, head, seq_len, pos_head_dim) x (1, head, pos_head_dim, seq_len2) -> (batch, head, seq_len, seq_len2) pos_weights = torch.matmul(p, x) # the following .as_strided() expression converts the last axis of pos_weights from relative diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py index 77a3b12640..d98ec3c253 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_modules.py +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -26,7 +26,7 @@ import torch.nn as nn from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd - +from zapformer_utils import limit_param_value From bcf578fae13d9ae31e731fac51a6b22d160d59d2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Mar 2026 17:43:48 +0800 Subject: [PATCH 0990/1191] Bug fixes. --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 1 + egs/librispeech/ASR/zapformer/multicopy_dataset.py | 6 +++++- icefall/__init__.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 12a894e818..50b4ff7614 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -299,6 +299,7 @@ def train_dataloaders( # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. train = MulticopyDataset( + num_copies=self.args.num_copies, cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, diff --git a/egs/librispeech/ASR/zapformer/multicopy_dataset.py b/egs/librispeech/ASR/zapformer/multicopy_dataset.py index 2e6f145690..a41e9b4a1a 100755 --- a/egs/librispeech/ASR/zapformer/multicopy_dataset.py +++ b/egs/librispeech/ASR/zapformer/multicopy_dataset.py @@ -108,7 +108,11 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]] # Sort the cuts by duration so that the first one determines the batch time dimensions. cuts = cuts.sort_by_duration(ascending=False) - cuts = cuts.repeat(times=self.num_copies) + if self.num_copies > 1: + cuts = cuts.repeat(times=self.num_copies) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) # Get a tensor with batched feature matrices, shape (B, T, F) # Collation performs auto-padding, if necessary. diff --git a/icefall/__init__.py b/icefall/__init__.py index 831d66f0a1..b1e4313e9b 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,6 +1,6 @@ # isort:skip_file -from . import checkpoint, decode, dist, env, utils, exp_augment +from . import checkpoint, decode, dist, env, utils from .byte_utils import ( byte_decode, From 85336279534801d7f5f0a2da36854027e8ff1260 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 30 Mar 2026 11:34:31 +0800 Subject: [PATCH 0991/1191] Make learning rate of scale be propto scale and make bias scale_limits 4 times larger. --- .../ASR/zapformer/batched_rubik.py | 19 ++++++++++--------- egs/librispeech/ASR/zapformer/rubik.py | 19 ++++++++++--------- egs/librispeech/ASR/zapformer/train.py | 1 - 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index e3ed9ecc04..abd22f3ae1 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -180,8 +180,6 @@ def cubic_decay_step(group, state, grad): cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion - min_scale, max_scale = group["scale_limits"] - try: stored_delta = state["delta"] except KeyError as e: @@ -283,6 +281,10 @@ def scaling_step(group, param, state, grad): lr = group["lr"] wd = group["wd"] + momentum = 0.95 + is_weight = grad.ndim >= 3 + min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] + if grad.ndim >= 3 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): delta = cubic_decay_step(group, state, grad) else: @@ -294,22 +296,19 @@ def scaling_step(group, param, state, grad): scale_grad_buf = state["scale_grad_buffer"] except: shape = [ param.shape[0] ] + [1] * (param.ndim - 1) - scale = torch.ones(*shape, device=grad.device) + scale = min_scale * torch.ones(*shape, device=grad.device) # initialize scale to min_scale scale_grad_buf = torch.zeros(*shape, device=grad.device) state["scale"] = scale state["scale_grad_buffer"] = scale_grad_buf - momentum = 0.95 - min_scale, max_scale = group["scale_limits"] - dims = list(range(1, param.ndim)) scale_grad = (grad * param.detach()).sum(dim=dims, keepdim=True) scale_grad_buf.mul_(momentum).add_(scale_grad) old_scale = scale.clone() + scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr) - scale.add_(scale_grad_buf.sign(), alpha=-lr) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale @@ -387,7 +386,8 @@ def __init__( beta2=0.98, wd=12, eps=1.0e-16, - scale_limits=(1.0, 4.0), + weight_scale_limits=(1.0, 4.0), + bias_scale_limits=(4.0, 16.0), ): defaults = dict( @@ -398,7 +398,8 @@ def __init__( beta2=beta2, eps=eps, wd=wd, - scale_limits=scale_limits, + weight_scale_limits=weight_scale_limits, + bias_scale_limits=bias_scale_limits, ) param_groups, parameters_names = self._get_names_of_parameters(params) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index cbb62a2ca7..78bc3b0e96 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -82,8 +82,6 @@ def cubic_decay_step(group, state, grad): cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion - min_scale, max_scale = group["scale_limits"] - try: stored_delta = state["delta"] except KeyError as e: @@ -178,6 +176,10 @@ def scaling_step(group, param, state, grad): lr = group["lr"] wd = group["wd"] + momentum = 0.95 + is_weight = grad.ndim >= 2 + min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] + if grad.ndim >= 2 and grad.numel() != max(grad.shape): delta = cubic_decay_step(group, state, grad) else: @@ -188,21 +190,18 @@ def scaling_step(group, param, state, grad): scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] except: - scale = torch.ones(1, device=grad.device) + scale = min_scale * torch.ones(1, device=grad.device) # initialize scale to min_scale scale_grad_buf = torch.zeros(1, device=grad.device) state["scale"] = scale state["scale_grad_buffer"] = scale_grad_buf - momentum = 0.95 - min_scale, max_scale = group["scale_limits"] - scale_grad = (grad * param.detach()).sum() scale_grad_buf.mul_(momentum).add_(scale_grad) old_scale = scale.clone() - scale.add_(scale_grad_buf.sign(), alpha=-lr) + scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale @@ -269,7 +268,8 @@ def __init__( beta2=0.98, wd=12, eps=1.0e-16, - scale_limits=(1.0, 4.0), + weight_scale_limits=(1.0, 4.0), + bias_scale_limits=(4.0, 16.0), ): defaults = dict( lr=lr, @@ -279,7 +279,8 @@ def __init__( beta2=beta2, eps=eps, wd=wd, - scale_limits=scale_limits, + weight_scale_limits=weight_scale_limits, + bias_scale_limits=bias_scale_limits, ) super().__init__(params, defaults) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 79e6b3306a..514b320ece 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1380,7 +1380,6 @@ def run(rank, world_size, args): cubic_decay_proportion=0.8, wd=18, beta1=0.995, - scale_limits=(1.0, 4.0), ) # hardcode batches per epoch for now. From 286d09a6c6b76b95b06eb6e45a5006240da3252c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 30 Mar 2026 13:05:30 +0800 Subject: [PATCH 0992/1191] Remove specifiable weight decay from rubik, merge into bias and weight scales. --- .../ASR/zapformer/batched_rubik.py | 137 ++---------------- egs/librispeech/ASR/zapformer/rubik.py | 26 ++-- 2 files changed, 32 insertions(+), 131 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index abd22f3ae1..642b750337 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -279,11 +279,13 @@ def min_sum_scale(x, y): def scaling_step(group, param, state, grad): lr = group["lr"] - wd = group["wd"] momentum = 0.95 is_weight = grad.ndim >= 3 min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] + # the "scale" is implicitly a scalar, even though it is learned in log space; apply scalar_scale to its + # learning rate. + scalar_scale = group["scalar_scale"] if grad.ndim >= 3 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): delta = cubic_decay_step(group, state, grad) @@ -307,13 +309,13 @@ def scaling_step(group, param, state, grad): scale_grad_buf.mul_(momentum).add_(scale_grad) old_scale = scale.clone() - scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr) + scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr * scalar_scale) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale - delta_scale = (scale_ratio * (1 - (lr * wd) ** 2)) - 1 + delta_scale = (scale_ratio * (1 - lr ** 2)) - 1 return param * delta_scale + scale * delta @@ -379,15 +381,15 @@ class BatchedRubik(BatchedOptimizer): def __init__( self, params, - lr=1e-03, + lr=1.2e-02, beta1=0.995, direct=0.15, # scale on bypass of momentum (beta1) cubic_decay_proportion=0.8, beta2=0.98, - wd=12, eps=1.0e-16, - weight_scale_limits=(1.0, 4.0), - bias_scale_limits=(4.0, 16.0), + weight_scale_limits=(0.05, 0.25), + bias_scale_limits=(0.2, 1.0), + scalar_scale=0.075, ): defaults = dict( @@ -397,9 +399,9 @@ def __init__( cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, - wd=wd, weight_scale_limits=weight_scale_limits, bias_scale_limits=bias_scale_limits, + scalar_scale=scalar_scale, ) param_groups, parameters_names = self._get_names_of_parameters(params) @@ -547,7 +549,10 @@ def step(self, closure=None): cur_step = 0 if p.numel() == p.shape[0]: - p += adam_step(group, state, grad) + # "scalar_scale" the assumed parameter scale used for + # scalars, in this case it just acts as a multiplier on + # the learning rate. + p += group["scalar_scale"] * adam_step(group, state, grad) else: p += scaling_step(group, p.detach(), state, grad) @@ -597,11 +602,11 @@ def _test_batched_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.001 + lr = 0.015 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 - optim = BatchedRubik(m.named_parameters(), lr=lr, direct=0.0, beta1=0.999) + optim = BatchedRubik(m.parameters(), lr=lr, direct=0.0, beta1=0.999) num_epochs = 180 @@ -666,116 +671,6 @@ def lr_lambda(current_step): logging.info(f"output_magnitudes = {output_magnitudes}") -def _test_muon(hidden_dim: int): - import timeit - - from muon import Muon - - E = 100 - B = 4 - T = 2 - logging.info("in test_muon") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. - input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - if True: - fix_random_seed(42) - Linear = torch.nn.Linear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - optim = Muon(m.parameters(), - lr=0.5e-03, - wd=12.0) - - num_epochs = 180 - # hardcode batches per epoch for now. - total_steps = num_epochs - constant_fraction = 0.25 - - def lr_lambda(current_step): - progress = current_step / total_steps - if progress < constant_fraction: - return 1.0 - else: - return (1.0 - progress) / (1.0 - constant_fraction) - - scheduler = LambdaLR(optim, lr_lambda) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(num_epochs): - scheduler.step() - - # if epoch == 100 and test in [2,3]: - # optim.reset_speedup() # check it doesn't crash. - - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 512 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - norm2 = '%.2e' % (m[1].weight**2).mean().sqrt().item() - norm3 = '%.2e' % (m[3].weight**2).mean().sqrt().item() - norm4 = '%.2e' % (m[5].weight**2).mean().sqrt().item() - - bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() - bias_norm2 = '%.2e' % (m[3].bias**2).mean().sqrt().item() - bias_norm3 = '%.2e' % (m[5].bias**2).mean().sqrt().item() - - lr = scheduler.get_last_lr()[0] - logging.info( - f"Test muon, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3,norm4}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" - ) - loss.log().backward() - optim.step() - optim.zero_grad() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Muon: time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - def _test_compute_scaled_prod3(): x = torch.randn(3, 16, 32) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 78bc3b0e96..75b61af36c 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -174,11 +174,14 @@ def min_sum_scale(x, y): def scaling_step(group, param, state, grad): lr = group["lr"] - wd = group["wd"] momentum = 0.95 is_weight = grad.ndim >= 2 min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] + # the "scale" is implicitly a scalar, even though it is learned in log space; apply scalar_scale to its + # learning rate. + scalar_scale = group["scalar_scale"] + if grad.ndim >= 2 and grad.numel() != max(grad.shape): delta = cubic_decay_step(group, state, grad) @@ -201,12 +204,12 @@ def scaling_step(group, param, state, grad): old_scale = scale.clone() - scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr) + scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr * scalar_scale) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale - delta_scale = (scale_ratio * (1 - (lr * wd) ** 2)) - 1 + delta_scale = (scale_ratio * (1 - (lr ** 2))) - 1 return param * delta_scale + scale * delta @@ -261,15 +264,15 @@ class Rubik(Optimizer): def __init__( self, params, - lr=1e-03, + lr=1.2e-02, beta1=0.995, direct=0.15, # scale on bypass of momentum (beta1) cubic_decay_proportion=0.8, beta2=0.98, - wd=12, eps=1.0e-16, - weight_scale_limits=(1.0, 4.0), - bias_scale_limits=(4.0, 16.0), + weight_scale_limits=(0.05, 0.25), + bias_scale_limits=(0.2, 1.0), + scalar_scale=0.075, ): defaults = dict( lr=lr, @@ -278,9 +281,9 @@ def __init__( cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, - wd=wd, weight_scale_limits=weight_scale_limits, bias_scale_limits=bias_scale_limits, + scalar_scale=scalar_scale, ) super().__init__(params, defaults) @@ -319,7 +322,10 @@ def u(x): return x.unsqueeze(0) if p.numel() == 1: - p += adam_step(group, state, grad) + # "scalar_scale" the assumed parameter scale used for + # scalars, in this case it just acts as a multiplier on + # the learning rate. + p += group["scalar_scale"] * adam_step(group, state, grad) else: p += scaling_step(group, u(p.detach()), state, u(grad))[0] @@ -368,7 +374,7 @@ def _test_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.001 + lr = 0.015 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 From 85d6b325438d3c60143a45282ad8a2edd7b85094 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 30 Mar 2026 13:25:46 +0800 Subject: [PATCH 0993/1191] Initialize scales to the actual parameter scales. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 5 ++--- egs/librispeech/ASR/zapformer/rubik.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 642b750337..af7d80d120 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -297,9 +297,8 @@ def scaling_step(group, param, state, grad): scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] except: - shape = [ param.shape[0] ] + [1] * (param.ndim - 1) - scale = min_scale * torch.ones(*shape, device=grad.device) # initialize scale to min_scale - scale_grad_buf = torch.zeros(*shape, device=grad.device) + scale = (param ** 2).mean(dim=list(range(1, param.ndim)), keepdim=True).sqrt().clamp(min=min_scale, max=max_scale).to(torch.float) + scale_grad_buf = torch.zeros_like(scale) state["scale"] = scale state["scale_grad_buffer"] = scale_grad_buf diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 75b61af36c..1c3353ca78 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -193,8 +193,8 @@ def scaling_step(group, param, state, grad): scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] except: - scale = min_scale * torch.ones(1, device=grad.device) # initialize scale to min_scale - scale_grad_buf = torch.zeros(1, device=grad.device) + scale = (param ** 2).mean().sqrt().clamp(min=min_scale, max=max_scale).to(torch.float) + scale_grad_buf = torch.zeros_like(scale) state["scale"] = scale state["scale_grad_buffer"] = scale_grad_buf From f7ef39f9911a74a35c3007dd886f02a9b2c48008 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 30 Mar 2026 13:36:51 +0800 Subject: [PATCH 0994/1191] Remove weight decay arg. --- egs/librispeech/ASR/zapformer/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 514b320ece..1757b2b84e 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1378,7 +1378,6 @@ def run(rank, world_size, args): lr=params.base_lr, direct=0.15, cubic_decay_proportion=0.8, - wd=18, beta1=0.995, ) From 0eb40574644e6ebf42ef8381b3002b16671f9981 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 30 Mar 2026 15:17:02 +0800 Subject: [PATCH 0995/1191] Change BasisConv to normal convolution, with bias=False; still in parallel with WeightedMean. --- egs/librispeech/ASR/zapformer/train.py | 2 +- egs/librispeech/ASR/zapformer/zapformer.py | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 1757b2b84e..459b781592 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -251,7 +251,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--conv-params", type=str, - default="32,32,16,32", + default="31,31,15,31", help="Parameters per channel of convolution kernels", ) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index c226e73914..3ba91c5c6e 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -1770,9 +1770,16 @@ def __init__( if not causal: - self.depthwise_conv = BasisConv(bottleneck_dim, - num_freqs=kernel_size*2, - params_per_channel=kernel_size) + assert kernel_size % 2 == 1 + self.depthwise_conv = nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=False, + ) + else: self.depthwise_conv = nn.Conv1d( in_channels=bottleneck_dim, @@ -1780,7 +1787,7 @@ def __init__( groups=bottleneck_dim, kernel_size=kernel_size, padding=0, # will pad manually, on one side. - bias=True, + bias=False, ) self.left_pad = kernel_size - 1 @@ -1834,17 +1841,17 @@ def forward( wm = self.weighted_mean(x, src_key_padding_mask) + x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) if self.causal: # Not support exporting a model for simulated streaming decoding assert not torch.jit.is_scripting() and not torch.jit.is_tracing() - x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) x_shape = x.shape x = torch.nn.functional.pad(x, (self.left_pad, 0)) x = self.depthwise_conv(x) assert x.shape == x_shape, (x.shape, x_shape) - x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) else: x = self.depthwise_conv(x) # x: (time, batch, bottleneck_dim) + x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) x = x + wm # Add in the weighted-mean to the convolution; this adds extra power # because the utterances differ in length. From ef3233bd4baea9eb591f87835623fba29dfc0e0d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 31 Mar 2026 11:41:35 +0800 Subject: [PATCH 0996/1191] Halve bias_scale_limits from (0.2,1.0) to (0.1,0.5). --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- egs/librispeech/ASR/zapformer/rubik.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index af7d80d120..4fbe5cd48e 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -387,7 +387,7 @@ def __init__( beta2=0.98, eps=1.0e-16, weight_scale_limits=(0.05, 0.25), - bias_scale_limits=(0.2, 1.0), + bias_scale_limits=(0.1, 0.5), scalar_scale=0.075, ): diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 1c3353ca78..1840e4c943 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -271,7 +271,7 @@ def __init__( beta2=0.98, eps=1.0e-16, weight_scale_limits=(0.05, 0.25), - bias_scale_limits=(0.2, 1.0), + bias_scale_limits=(0.1, 0.5), scalar_scale=0.075, ): defaults = dict( From 0dd7b9527f3fabb5508fbe4d6944ef85c72b7591 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 31 Mar 2026 13:14:48 +0800 Subject: [PATCH 0997/1191] Do not respect pairs of sequences in AlternatingSpecAugment, use invisible twins; select 0-or-1 offset randomly. --- .../ASR/zapformer/alternating_spec_augment.py | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 6bf6038254..93e4f0bb3d 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -12,9 +12,9 @@ class AlternatingSpecAugment(torch.nn.Module): from lhotse which is the same as the original SpecAugment). The main difference is in how it selects the regions to be masked, they are selected - for pairs of sequences in such a way that there tends to be a good amount of spacing between - masked regions; the masked regions never overlap and will never be extremely close to - each other. We also use a relatively large masked-fraction + in a way that usually ensures there is a good amount of space between successive masks. + We also use a relatively large temporal masked-fraction (max_frame_mask_fraction) + and have the number of masks be selected proportional to the utterance length. """ def __init__( self, @@ -158,12 +158,11 @@ def _mask_on_axis(self, def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_fraction, device) -> Tuple[Tuple,Tuple]: - # compute the start and end positions of masked regions. this will select mask positions - # that do not overlap. Return: (mask_starts, mask_ends). - - # we sample the masks for pairs of sequences. - B = (batch_size + 1) // 2 - # M is the number of masks we sample for each pair of sequences. + # we imagine there are "pairs of sequences" for historical reasons but one of each pair is not + # a real sequence. + B = batch_size + # M is the number of masks we sample for each "pair of sequences." (i.e. for each sequence and its + # imaginary twin) M = 2 * num_masks # "rlength" means relative length of each mask, i.e. relative to seq_len. the @@ -212,17 +211,12 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ assert mask_starts.shape == (B, M) and mask_ends.shape == (B, M) - # letting A,B be randomly 0 or 1 avoids any overall bias towards the start or end of the - # sequence in case the batch size is odd. - A = random.randint(0, 1) - B = (A + 1) % 2 - mask_starts1 = mask_starts[:, A::2] - mask_ends1 = mask_ends[:, A::2] - mask_starts2 = mask_starts[:, B::2] - mask_ends2 = mask_ends[:, B::2] - - mask_starts = torch.cat((mask_starts1, mask_starts2), dim=0)[:batch_size] - mask_ends = torch.cat((mask_ends1, mask_ends2), dim=0)[:batch_size] + # letting the start-position when we take alternating positions be + # randomly 0 or 1 avoids any overall bias towards the start or end of + # the sequence. + index = torch.randint(0, 2, (B,), device=device).unsqueeze(-1) + torch.arange(0, M, step=2, device=device) + mask_starts = torch.gather(mask_starts, dim=1, index=index) + mask_ends = torch.gather(mask_ends, dim=1, index=index) return mask_starts, mask_ends From 3a8d333b6a1d862a83e2a6f88ee5906c8d8eefc1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Apr 2026 11:17:28 +0800 Subject: [PATCH 0998/1191] Add more diagnostic code in test. --- egs/librispeech/ASR/zapformer/alternating_spec_augment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 93e4f0bb3d..b05da7b005 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -275,8 +275,11 @@ def _test_alternating_spec_augment(): frame_is_masked = features[:, :, 0] == features[:, :, -1] print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) + + print("mean frame_is_masked[per-frame][::10] = ", frame_is_masked.to(torch.float).mean(dim=0)[::10]) feature_is_masked = features[:, 0] == features[:, -1] print("mean feature_is_masked = ", feature_is_masked.to(torch.float).mean()) + print("mean feature_is_masked[per-freq] = ", feature_is_masked.to(torch.float).mean(dim=0)) From 0230f0b332bd59110f944d7e40116a39023e8767 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Apr 2026 14:29:50 +0800 Subject: [PATCH 0999/1191] Apply weighted_mean in self-attention also --- egs/librispeech/ASR/zapformer/zapformer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index c226e73914..2e54422ec2 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -543,6 +543,7 @@ def __init__( query_head_dim=query_head_dim, value_head_dim=value_head_dim, pos_head_dim=pos_head_dim, + causal=causal, ) feedforward_dim = embed_dim * feedforward_multiple @@ -897,8 +898,9 @@ def __init__( embed_dim: int, num_heads: int, query_head_dim: int, - pos_head_dim: int = 4, - value_head_dim: int = 12, + pos_head_dim: int , + value_head_dim: int, + causal: bool, ) -> None: super().__init__() self.embed_dim = embed_dim @@ -938,6 +940,8 @@ def __init__( num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.5 ) + self.weighted_mean = WeightedMean(num_heads * value_head_dim, causal) # TODO: fix causal option + def forward( self, x_qkp: Tensor, @@ -1035,6 +1039,11 @@ def forward( v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) + + # v, g: (seq_len, batch_size, num_heads * value_head_dim) + + wm = self.weighted_mean(v, key_padding_mask) + v = v.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) v = self.copy_v(v) value_head_dim = v.shape[-1] @@ -1055,6 +1064,7 @@ def forward( g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + v = v + wm v = v * self.sigmoid(g) v = self.out_proj(v) return v From dbe49d623eab086267287de8ff5a287122684a2b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Apr 2026 16:28:40 +0800 Subject: [PATCH 1000/1191] Bug fixes RE CV test and --max-duration/--num-copies types from 2236. --- .../ASR/zapformer/asr_datamodule.py | 4 +- egs/librispeech/ASR/zapformer/ctc_decode.py | 57 +++++++++++++++---- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 50b4ff7614..29842c72fc 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -115,7 +115,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): ) group.add_argument( "--max-duration", - type=int, + type=float, default=800.0, help="Maximum pooled recordings duration (seconds) in a " "single batch, including the --num-copies argument, so if --num-copies " @@ -212,7 +212,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--num-copies", - type=str, + type=int, default=4, help="The number of copies of each training example selected in each batch (they will be augmented " "differently). If you make num-copies larger there will be more steps per epoch so you should probably make " diff --git a/egs/librispeech/ASR/zapformer/ctc_decode.py b/egs/librispeech/ASR/zapformer/ctc_decode.py index cbbc7313d1..963e4f2047 100755 --- a/egs/librispeech/ASR/zapformer/ctc_decode.py +++ b/egs/librispeech/ASR/zapformer/ctc_decode.py @@ -112,6 +112,7 @@ import logging import math import os +import re from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -120,7 +121,7 @@ import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeech, GigaSpeech, AsrDataModule +from asr_datamodule import CommonVoice, LibriSpeech, GigaSpeech, AsrDataModule from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params @@ -189,7 +190,7 @@ ) -def asr_text_post_processing(text: str) -> str: # only used for gigaspeech +def giga_asr_text_post_processing(text: str) -> str: # only used for gigaspeech # 1. convert to uppercase text = text.upper() @@ -206,13 +207,27 @@ def asr_text_post_processing(text: str) -> str: # only used for gigaspeech return " ".join(remaining_words) -def post_processing( + +def giga_post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = giga_asr_text_post_processing(" ".join(ref)).split() + new_hyp = giga_asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def cv_post_processing( results: List[Tuple[str, List[str], List[str]]], ) -> List[Tuple[str, List[str], List[str]]]: + def normalize(text): + return re.sub(r'[^\w\s]', '', text).upper() new_results = [] for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_ref = normalize(" ".join(ref)).split() + new_hyp = normalize(" ".join(hyp)).split() new_results.append((key, new_ref, new_hyp)) return new_results @@ -442,6 +457,13 @@ def get_parser(): help="""Skip scoring, but still save the ASR output (for eval sets).""", ) + parser.add_argument( + "--cv", + type=str2bool, + default=False, + help="""If True, decode commonvoice in addition to librispeech test sets.""", + ) + add_model_arguments(parser) return parser @@ -857,8 +879,10 @@ def save_asr_output( recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - if params.giga: - results = post_processing(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) store_transcripts(filename=recogs_filename, texts=results) @@ -881,8 +905,10 @@ def save_wer_results( test_set_wers = dict() for key, results in results_dict.items(): - if params.giga: - results = post_processing(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. @@ -1238,8 +1264,17 @@ def main(): dev_cuts = gigaspeech.dev_cuts() giga_test_dl = asr_datamodule.test_dataloaders(test_cuts) giga_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) - test_sets += ["dev", "test"] - test_dl += [giga_test_dl, giga_dev_dl] + test_sets += ["giga-dev", "giga-test"] + test_dl += [giga_dev_dl, giga_test_dl] + + if args.cv: + commonvoice = CommonVoice(args.manifest_dir) + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + cv_test_dl = asr_datamodule.test_dataloaders(test_cuts) + cv_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) + test_sets += ["cv-dev", "cv-test"] + test_dl += [cv_dev_dl, cv_test_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( From 0602f1f4e20fc4288914b67a609e817899d9aaae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Apr 2026 16:06:40 +0800 Subject: [PATCH 1001/1191] Add input sigmoid gating in self-attn. --- egs/librispeech/ASR/zapformer/zapformer.py | 24 +++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 2e54422ec2..1d31284f61 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -929,11 +929,12 @@ def __init__( self.copy_pos_query = Identity() # value and gating in_proj. - self.vg_in_proj = ScaledLinear(embed_dim, 2 * num_heads * value_head_dim, + self.vg_in_proj = ScaledLinear(embed_dim, 3 * num_heads * value_head_dim, initial_scale=0.1, bias=True) self.copy_v = nn.Identity() # diagnostics. - self.sigmoid = nn.Sigmoid() + self.sigmoid_in = nn.Sigmoid() + self.sigmoid_out = nn.Sigmoid() # out proj for the value times gating. self.out_proj = ScaledLinear( @@ -1037,13 +1038,19 @@ def forward( elif random.random() < 0.001: self._print_attn_entropy(attn_weights) - - v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) - - # v, g: (seq_len, batch_size, num_heads * value_head_dim) + vg = self.vg_in_proj(x_vg) + N = vg.shape[-1] // 3 + v = vg[..., :N] + g = vg[..., N:] + if self.training: + # don't let the sigmoid values get too extreme, limit to -2..2. + g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) wm = self.weighted_mean(v, key_padding_mask) + g_in, g_out = g.chunk(2, dim=-1) + v = v * self.sigmoid_in(g_in) + v = v.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) v = self.copy_v(v) value_head_dim = v.shape[-1] @@ -1059,13 +1066,10 @@ def forward( .view(seq_len, batch_size, num_heads * value_head_dim) ) - if self.training: - # don't let the sigmoid values get too extreme, limit to -2..2. - g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. v = v + wm - v = v * self.sigmoid(g) + v = v * self.sigmoid_out(g_out) v = self.out_proj(v) return v From 3041218b6deb4030b64ea6d9b4903994b2f30eb3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Apr 2026 16:47:27 +0800 Subject: [PATCH 1002/1191] Actually apply mask in weighted_mean; put sigmoid gating before weighted_mean. --- egs/librispeech/ASR/zapformer/zapformer.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 1d31284f61..b33c342a7e 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -1046,11 +1046,11 @@ def forward( # don't let the sigmoid values get too extreme, limit to -2..2. g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) - wm = self.weighted_mean(v, key_padding_mask) - g_in, g_out = g.chunk(2, dim=-1) v = v * self.sigmoid_in(g_in) + wm = self.weighted_mean(v, key_padding_mask, apply_mask=True) + v = v.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) v = self.copy_v(v) value_head_dim = v.shape[-1] @@ -1671,7 +1671,8 @@ def __init__(self, def forward(self, x: Tensor, - src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + src_key_padding_mask: Optional[Tensor] = None, + apply_mask: bool = True) -> Tensor: """ Compute weighted mean. x: (time, batch, channel) @@ -1688,9 +1689,13 @@ def forward(self, # assume x already masked, if mask is in use. if src_key_padding_mask is not None: - num_frames = src_key_padding_mask.logical_not().to(torch.float).sum(dim=1) + mask = src_key_padding_mask.logical_not().to(torch.float) + num_frames = mask.sum(dim=1) num_frames = num_frames.unsqueeze(-1).to(torch.float) + if apply_mask: + x = x * mask.t().unsqueeze(-1) + # num_frames: (batch_size, 1) return x.mean(dim=0) * (T / num_frames) * self.weights else: @@ -1847,7 +1852,9 @@ def forward( x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - wm = self.weighted_mean(x, src_key_padding_mask) + wm = self.weighted_mean(x, + src_key_padding_mask, + apply_mask=False) # just applied it. if self.causal: # Not support exporting a model for simulated streaming decoding assert not torch.jit.is_scripting() and not torch.jit.is_tracing() From 4255790e5295b369ae023fdc54e5fb49477d57b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 2 Apr 2026 23:59:12 +0800 Subject: [PATCH 1003/1191] Increase value-head-dim from 64 to 98; reduce central num layers from 14 to 12. --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 1757b2b84e..3cb510daad 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -181,7 +181,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,14,8", + default="6,8,12,8", help="Number of zapformer encoder layers per stack, comma separated.", ) @@ -237,7 +237,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--value-head-dim", type=str, - default="64", + default="96", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) From 894149c9736e6fb9078978e53ad032b58f49f0fa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Apr 2026 15:17:21 +0800 Subject: [PATCH 1004/1191] Remove depthwise_conv.lr_cale = 0.66 --- egs/librispeech/ASR/zapformer/zapformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 06032f2d97..7378bb8894 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -1810,7 +1810,6 @@ def __init__( ) self.left_pad = kernel_size - 1 - self.depthwise_conv.lr_scale = 0.66 # add average-of-all-frames to the "convolution."; it has extra power vs the convolution # because the num frames differs between utterances. self.weighted_mean = WeightedMean(bottleneck_dim, From 9485532060a34ecd36cc8d0c0b3a3b7e3ba5d9c4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Apr 2026 15:41:03 +0800 Subject: [PATCH 1005/1191] Restore depthwise_conv.lr_scale = 0.66 --- egs/librispeech/ASR/zapformer/zapformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 7378bb8894..a5ad2df8a8 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -1809,6 +1809,7 @@ def __init__( bias=False, ) self.left_pad = kernel_size - 1 + self.depthwise_conv.lr_scale = 0.66 # add average-of-all-frames to the "convolution."; it has extra power vs the convolution # because the num frames differs between utterances. From e1ada1f51da22d86b564fffbded904f30719e660 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Apr 2026 14:52:22 +0800 Subject: [PATCH 1006/1191] Remove zapformer_denoise directory --- .../ASR/zapformer_denoise/asr_datamodule.py | 448 ------ .../ASR/zapformer_denoise/decode.py | 537 ------- .../zapformer_denoise/decode_gigaspeech.py | 1 - .../zapformer_denoise/encoder_interface.py | 1 - .../ASR/zapformer_denoise/export-onnx-ctc.py | 1 - .../export-onnx-streaming-ctc.py | 1 - .../export-onnx-streaming.py | 1 - .../ASR/zapformer_denoise/export-onnx.py | 1 - .../ASR/zapformer_denoise/export.py | 1 - .../ASR/zapformer_denoise/finetune.py | 1 - .../generate_averaged_model.py | 1 - .../ASR/zapformer_denoise/label_smoothing.py | 1 - .../ASR/zapformer_denoise/model.py | 388 ----- .../ASR/zapformer_denoise/optim.py | 1 - .../ASR/zapformer_denoise/pretrained.py | 1 - .../ASR/zapformer_denoise/scaling.py | 1 - .../zapformer_denoise/speech_recognition.py | 229 --- .../ASR/zapformer_denoise/subsampling.py | 297 ---- .../ASR/zapformer_denoise/test_scaling.py | 1 - .../ASR/zapformer_denoise/train.py | 1378 ----------------- .../ASR/zapformer_denoise/zapformer.py | 1344 ---------------- 21 files changed, 4635 deletions(-) delete mode 100755 egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py delete mode 100755 egs/librispeech/ASR/zapformer_denoise/decode.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/encoder_interface.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/export-onnx.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/export.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/finetune.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/label_smoothing.py delete mode 100755 egs/librispeech/ASR/zapformer_denoise/model.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/optim.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/pretrained.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/scaling.py delete mode 100755 egs/librispeech/ASR/zapformer_denoise/speech_recognition.py delete mode 100644 egs/librispeech/ASR/zapformer_denoise/subsampling.py delete mode 120000 egs/librispeech/ASR/zapformer_denoise/test_scaling.py delete mode 100755 egs/librispeech/ASR/zapformer_denoise/train.py delete mode 100644 egs/librispeech/ASR/zapformer_denoise/zapformer.py diff --git a/egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py b/egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py deleted file mode 100755 index 09513afbe0..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/asr_datamodule.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - K2SpeechRecognitionDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="""Used only when --mini-libri is False.When enabled, - use 960h LibriSpeech. Otherwise, use 100h subset.""", - ) - group.add_argument( - "--mini-libri", - type=str2bool, - default=False, - help="True for mini librispeech", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=[], - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_clean_5_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get train-clean-5 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" - ) - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def dev_clean_2_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get dev-clean-2 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" - ) - - @lru_cache() - def gigaspeech_subset_small_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech subset-S cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") - - @lru_cache() - def gigaspeech_dev_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") - - @lru_cache() - def gigaspeech_test_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zapformer_denoise/decode.py b/egs/librispeech/ASR/zapformer_denoise/decode.py deleted file mode 100755 index dedf092b82..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/decode.py +++ /dev/null @@ -1,537 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule - -from lhotse import set_caching_enabled -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--num-steps", - type=int, - default=8, - help="""The number of time-steps in denoising decoding.""" - ) - - parser.add_argument( - "--eps", - type=float, - default=1.0e-04, - help="""The t value that we start from with pure noise.""" - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--skip-scoring", - type=str2bool, - default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, -) -> List[List[str]]: - """Decode one batch and return the result as a list of sentences - (each sentence is a list of words). - - Args: - params: - The return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - The return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - Returns: - Return the decoding result as a list of list of strings (words), i.e. - a list of sentences. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - tokens = model.infer(feature, feature_lens, params.eps, params.num_steps) # list of lists of int - - hyps = [ sp.decode(t).split() for t in tokens ] # list of lists of str - - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, -) -> List[Tuple[str, List[str], List[str]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Returns list of tuples (cut_id, ref_transcript, hyp_transcript) - with types (str, List[str], List[str]). - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - - log_interval = 10 - - - results = [ ] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps = decode_one_batch( - params=params, - model=model, - sp=sp, - batch=batch, - ) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_asr_output( - params: AttributeDict, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]] -): - """ - Save text produced by ASR. - """ - recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - - results = sorted(results) - store_transcripts(filename=recogs_filename, texts=results) - - logging.info(f"The transcripts are stored in {recogs_filename}") - - -def save_wer_results( - params: AttributeDict, - test_set_name: str, - results: List[Tuple[str, List[str], List[str], Tuple]], -): - """ - Save WER and per-utterance word alignments. - """ - - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w", encoding="utf8") as fd: - wer = write_error_stats( - fd, f"{test_set_name}", results, enable_log=True - ) - logging.info(f"Wrote detailed error stats to {errs_filename}") - - wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - - with open(wer_filename, "w", encoding="utf8") as fd: - print(f"{wer}", file=fd) - - s = f"\nFor {test_set_name}, WER is {wer}" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - # enable AudioCache - set_caching_enabled(True) # lhotse - - params.res_dir = params.exp_dir / "decode" - - if params.iter > 0: - params.suffix = f"iter-{params.iter}_avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" - - params.suffix += f"_{params.num_steps}step" - - if params.use_averaged_model: - params.suffix += "_use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - dev_clean_cuts = librispeech.dev_clean_cuts() - dev_other_cuts = librispeech.dev_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) - dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) - - test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] - test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - ) - - save_asr_output( - params=params, - test_set_name=test_set, - results=results, - ) - - if not params.skip_scoring: - save_wer_results( - params=params, - test_set_name=test_set, - results=results, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py deleted file mode 120000 index 63b0ef617b..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/decode_gigaspeech.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/encoder_interface.py b/egs/librispeech/ASR/zapformer_denoise/encoder_interface.py deleted file mode 120000 index aa5d0217a8..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py deleted file mode 120000 index dc14e93e75..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/export-onnx-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py deleted file mode 120000 index 3baa2b673c..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py deleted file mode 120000 index d18cb9a9a1..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export-onnx.py b/egs/librispeech/ASR/zapformer_denoise/export-onnx.py deleted file mode 120000 index f343cf7027..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/export.py b/egs/librispeech/ASR/zapformer_denoise/export.py deleted file mode 120000 index 1a126ab695..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/export.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/finetune.py b/egs/librispeech/ASR/zapformer_denoise/finetune.py deleted file mode 120000 index 0e9e7989b9..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/finetune.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py b/egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py deleted file mode 120000 index b65513a058..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/generate_averaged_model.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/label_smoothing.py b/egs/librispeech/ASR/zapformer_denoise/label_smoothing.py deleted file mode 120000 index 3690afff9d..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/model.py b/egs/librispeech/ASR/zapformer_denoise/model.py deleted file mode 100755 index 4452d5a61f..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/model.py +++ /dev/null @@ -1,388 +0,0 @@ -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple, List - -import k2 -import torch -import logging -import torch.nn as nn -from torch import Tensor -from scaling import ScaledLinear, convert_num_channels, SwashR -import math -from icefall.utils import make_pad_mask, time_warp - - - -class DenoisingAsrModel(nn.Module): - def __init__( - self, - #speech_embed: nn.Module, - encoder: nn.Module, - encoder_dim: int, - text_embed_dim: int, - vocab_size: int, - time_embed_dim: int, - ): - """ - TODO - """ - super().__init__() - - self.speech_scale = 0.1 - self.encoder = encoder - self.encoder_dim = encoder_dim - - # s is the time value for the speech, 0 <= s <= 1. - # t is the time value for the symbols, 0 <= t <= 1. - self.time_embed_dim = time_embed_dim - self.st_embed = nn.Sequential( - nn.Linear(time_embed_dim * 2, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim) - ) - - # randomly initialize text embedding and do not train it. - text_embed_scale = 0.25 # this will ensure that later steps still "matter". - self.text_embed = FixedEmbedding(vocab_size, text_embed_dim, scale=text_embed_scale) - - self.text_in_proj = nn.Linear(text_embed_dim, encoder_dim) - self.text_out_proj = nn.Linear(encoder_dim, text_embed_dim) - - # for now just hardcode - speech_channels = 80 - speech_subsample = 4 - self.speech_out_proj = nn.Linear(encoder_dim, - speech_channels * speech_subsample) - - self.speech_in_proj = nn.Linear(speech_channels * speech_subsample, - encoder_dim) - - - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: torch.Tensor, - y_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A Tensor of dtype long, indexed [utt][symbol], padded with symbol 0 - on the right. There is no BOS or EOS symbol. - - Returns: - Returns flow-matching loss values for symbols and speech respectively. - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - batch_size = x.shape[0] - assert x.shape[0] == x_lens.shape[0] == y.shape[0], (x.shape, x_lens.shape, y.shape) - - s = torch.empty(batch_size, device=x.device).uniform_(0.5, 1.0) # time-value for speech. only have >= 0.5 - t = torch.rand(batch_size, device=x.device) # time-value for text. - - - st = self.st_embed(torch.cat((timestep_embedding(s, self.time_embed_dim), - timestep_embedding(t, self.time_embed_dim)), dim=1)) - # st: (batch_size, time_embed_dim) - - (batch_size, speech_seq_len, num_freqs) = x.shape - - device = x.device - x1 = x * self.speech_scale # scale log-mels by 0.1 to be better matched to normal distribution. - x0 = torch.randn_like(x1) - xs = (x1 * s[:, None, None]) + (x0 * (1 - s[:, None, None])) - # x1, x0, xs: (batch_size, seq_len, 80) - xV = x1 - x0 # xV means x velocity. (batch_size, speech_seq_len, 80) - - padding = (4 - (speech_seq_len % 4)) % 4 - xs = torch.nn.functional.pad(xs, (0, 0, 0, padding)) - xs = xs.reshape(batch_size, -1, 4 * num_freqs) - xs_embed = self.speech_in_proj(xs) - x_lens_embed = x_lens // 4 - - xs_embed = xs_embed.permute(1, 0, 2) # (embed_seq_len, batch_size, encoder_dim) - embed_seq_len = xs_embed.shape[0] - - with torch.amp.autocast('cuda', enabled=False): - y = randomly_pad_to_lengths(y, y_lens, torch.minimum(x_lens_embed, y_lens + y_lens // 4), embed_seq_len) - # now y: (batch_size, seq_len) - y1 = self.text_embed(y) - # now y1: (batch_size, seq_len, text_embed_dim) - y0 = torch.randn_like(y1) - yt = (y1 * t[:, None, None]) + (y0 * (1 - t[:, None, None])) - # yt: (batch_size, seq_len, text_embed_dim) - yt_embed = self.text_in_proj(yt).permute(1, 0, 2) # (embed_seq_len, batch_size, encoder_dim) - yV = y1 - y0 # yV means y velocity. (batch_size, embed_seq_len, text_embed_dim) - - encoder_in = xs_embed + yt_embed - - src_key_padding_mask = torch.arange(0, embed_seq_len, device=x.device) >= x_lens_embed.unsqueeze(-1) # (batch-size, max_x_len) - - encoder_out = self.encoder(encoder_in, st, x_lens_embed, src_key_padding_mask) - (embed_seq_len, batch_size, _encoder_dim) = encoder_out.shape - - xU = self.speech_out_proj(encoder_out) - xU = xU.permute(1, 0, 2).reshape(batch_size, embed_seq_len * 4, -1) - xU = xU[:, :speech_seq_len] # (batch_size, speech_seq_len, 80) - - # don't use x_mask in training, this will simplify inference. - # x_mask = (torch.arange(0, speech_seq_len, device=x.device) < x_lens.unsqueeze(-1)).unsqueeze(-1) - # x_mask: # (batch-size, speech_seq_len, 1). - - x_loss = ((xV - xU) ** 2).mean(dim=-1).sum() - - yU = self.text_out_proj(encoder_out) - yU = yU.permute(1, 0, 2) # (batch_size, embed_seq_len, text_embed_dim) - - #y_mask = torch.logical_not(src_key_padding_mask).unsqueeze(-1) - y_loss = ((yV - yU) ** 2).mean(dim=-1).sum() - - return x_loss, y_loss # speech_loss, text_loss - - - def infer( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - eps: float, - num_steps: int, - ) -> List[List[int]]: - """ - Does inference. Starting from random noise representing the text, does inference - for a number of steps and then converts the text representation to integers. - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - eps: - The 't' value to start inference from, e.g. 1.0e-04 - num_steps: - The number of inference steps to use. - - Returns: - Returns the inference result as a list of lists of symbols, with blanks (symbol zero) - removed. - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - batch_size = x.shape[0] - assert x.shape[0] == x_lens.shape[0] - - s = torch.ones(batch_size, device=x.device) # time-value for speech is 1.0 throughout. - xs = x * self.speech_scale # scale log-mels by 0.1 to be better matched to normal distribution. - # xs is the same as x1, because s == 1.0, in inference there is no noise on the speech. - (batch_size, speech_seq_len, num_freqs) = xs.shape - padding = (4 - (speech_seq_len % 4)) % 4 - xs = torch.nn.functional.pad(xs, (0, 0, 0, padding)) - xs = xs.reshape(batch_size, -1, 4 * num_freqs) - xs_embed = self.speech_in_proj(xs) - x_lens_embed = x_lens // 4 - xs_embed = xs_embed.permute(1, 0, 2) # (embed_seq_len, batch_size, encoder_dim) - (embed_seq_len, batch_size, encoder_dim) = xs_embed.shape - src_key_padding_mask = torch.arange(0, embed_seq_len, device=x.device) >= x_lens_embed.unsqueeze(-1) # (batch-size, max_x_len) - text_embed_dim = self.text_embed.weight.shape[1] - - delta_t = (1.0 - eps) / num_steps - - yt = torch.randn(embed_seq_len, batch_size, text_embed_dim, device=x.device) # start with noise at t ~ 0 - - for step in range(num_steps): - t = torch.full((batch_size,), eps + step * delta_t, device=x.device) # time-value for text. - st = self.st_embed(torch.cat((timestep_embedding(s, self.time_embed_dim), - timestep_embedding(t, self.time_embed_dim)), dim=1)) - # st: (batch_size, time_embed_dim) - - - yt_embed = self.text_in_proj(yt) # (embed_seq_len, batch_size, encoder_dim) - encoder_in = xs_embed + yt_embed - encoder_out = self.encoder(encoder_in, st, x_lens_embed, src_key_padding_mask) - yU = self.text_out_proj(encoder_out) - - yt = yt + yU * delta_t - - - yt = yt.permute(1, 0, 2) # (batch_size, seq_len, text_embed_dim) - tokens, residual = find_closest_tokens(yt, self.text_embed.weight) - - logging.info(f"Avg residual is {residual}") - - tokens = tokens.tolist() - # remove blanks. - tokens = [ [ s for s in sent if s != 0 ] for sent in tokens ] - - return tokens - - - -class FixedEmbedding(nn.Module): - def __init__(self, vocab_size: int, embed_dim: int, scale: float = 1.0): - super().__init__() - self.register_buffer('weight', scale * torch.randn(vocab_size, embed_dim), - persistent=True) - - def forward(self, y: Tensor): - y_shape = y.shape - ans = torch.index_select(self.weight, 0, y.flatten()) - return ans.reshape(*y_shape, -1) - - - -def find_closest_tokens(y: Tensor, weights: Tensor) -> Tuple[Tensor, Tensor]: - """ - Find closest token indexes to embedding vectors. - Args: - y: (..., embed_dim), the embeddings to match to weights. - weights: (num_tokens, embed_dim), the embedding vectors for each token. - - Returns: (tokens, avg_residual) - tokens: (...), a LongTensor containing the indexes of the closest tokens - avg_residual: a LongTensor containing the average difference (rms of elements) - between embeddings and weights. - """ - yy = (y ** 2).sum(dim=-1) # (...) - ww = (weights ** 2).sum(dim=-1) # (num_tokens,) - yw = torch.matmul(y, weights.t()) # (..., num_tokens) - # (y - w) ** 2 = y**2 + w**2 - 2 yw - - residuals = yy.unsqueeze(-1) + ww - 2 * yw - residuals, tokens = torch.min(residuals, dim=-1) - - embed_dim = weights.shape[1] - return tokens, (residuals.mean() / embed_dim).sqrt() - - - -def timestep_embedding(timesteps, dim, max_period=10000): - """Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - -def randomly_pad_to_lengths(y: Tensor, - y_lens: Tensor, - x_lens: Tensor, - max_x_len: int): - """ - Randomly insert blanks (symbol 0) into the symbol-sequences in y, with lengths y_lens, so that - they have lengths x_lens. All tensor are LongTensors (dtype torch.long) - Args: - y: (batch_size, max_y_len): the symbols; all positions less than the corresponding y_lens value - are expected to be nonzero. - y_lens: the lengths of the sequences in y, we expect that 1 <= y_lens <= max_y_len - x_lens: the lengths of the sequences we want to pad to, we expect that y_lens <= x_lens <= max_x_len. - """ - # checking that each y is not longer than corresponding x. - debug = True #(__name__ == '__main__') - length_diff = x_lens - y_lens - if debug: - assert length_diff.min() >= 0 - - (batch_size, max_y_len) = y.shape - - - y_mask = torch.arange(0, max_y_len + 1, device=y.device) >= y_lens.unsqueeze(-1) # (batch-size, max_y_len) - # y_mask is True for masked, i.e. non-valid, positions - - # cut_points are points at which we divide up the interval [0..y_len-x_len] which is - # the amount by which we want to pad. We want to get y_len + 1 "padding lengths" that - # sum to y_len-x_len. We get these by taking the numbers: [ 0, , 1 , 1... ], - # multiplying by (y_len-x_len), so we have: [ 0, , y_len-x_len, y_len-x_len.. ], - # and take the differences between each one and the next, so we get: - # [ , 0, 0, ... ] and the counts add up to y_len-x_len. - # - cut_points = torch.rand(batch_size, max_y_len + 2, device=y.device) - cut_points[:, 1:].masked_fill_(y_mask, 1.0) - cut_points[:, 0] = 0.0 - cut_points = cut_points * length_diff.unsqueeze(-1) - cut_points = cut_points.sort(dim=1)[0] - cut_points = cut_points.round().to(torch.long) - num_pad = cut_points[:, 1:] - cut_points[:, :-1] - - - - num_symbols = torch.empty(batch_size, 2 * max_y_len, device=y.device, dtype=torch.long) - num_symbols[:, 1::2] = (1 - y_mask[:, :-1].to(torch.long)) # the actual symbols have length 1. - num_symbols[:, 0:-1:2] = num_pad[:, :-1] # assign the number of padding symbols for each position. - # we don't need the last padding length, it doesn't determine any symbol position. - - symbol_positions = num_symbols.cumsum(dim=1) - symbol_positions = symbol_positions[:, 0::2] - - # the "+ 1" is because the symbol_positions will actually contain, in the padding - # positions, a number equal to the corresponding values in x_lens; and this may - # be out of range in the scatter_ unless we add one padding element. - padded_symbols = torch.zeros(batch_size, max_x_len + 1, device=y.device, dtype=torch.long) - padded_symbols.scatter_(dim=1, index=symbol_positions, src=y) - padded_symbols = padded_symbols[:, :-1] # remove the one padding position - x_mask = torch.arange(0, max_x_len, device=y_lens.device) < x_lens.unsqueeze(-1) - if debug: - assert torch.all(padded_symbols == padded_symbols * x_mask) - return padded_symbols - - -def _test_find_closest_tokens(): - vocab_size = 10 - embed_dim = 30 - text_embed = FixedEmbedding(vocab_size, embed_dim) - tokens = torch.randint(0, vocab_size, (3, 4), dtype=torch.long) - - embeddings = text_embed(tokens) - embeddings = embeddings + 0.05 * torch.randn_like(embeddings) - - tokens2, residual = find_closest_tokens(embeddings, text_embed.weight) - print("Residual = ", residual) # should be around 0.05. - assert torch.all(tokens2 == tokens) - - -def _test_randomly_distribute_labels(): - y = torch.tensor([ [ 1, 2, 3, 4 ], [ 5, 6, 7, 0 ], [ 8, 9, 0, 0 ] ]) - y_lens = torch.tensor([ 4, 3, 2 ] ) - x_lens = torch.tensor([ 8, 6, 5 ]) - max_x_len = 7 - y = randomly_pad_to_lengths(y, y_lens, x_lens, max_x_len) - print("y_padded = ", y) - - - - -if __name__ == '__main__': - _test_find_closest_tokens() - for _ in range(10): - _test_randomly_distribute_labels() diff --git a/egs/librispeech/ASR/zapformer_denoise/optim.py b/egs/librispeech/ASR/zapformer_denoise/optim.py deleted file mode 120000 index 207eecfcda..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/optim.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/pretrained.py b/egs/librispeech/ASR/zapformer_denoise/pretrained.py deleted file mode 120000 index 70ad71ffc6..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/scaling.py b/egs/librispeech/ASR/zapformer_denoise/scaling.py deleted file mode 120000 index 58e4b0a0fe..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/speech_recognition.py b/egs/librispeech/ASR/zapformer_denoise/speech_recognition.py deleted file mode 100755 index dd069cf3da..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/speech_recognition.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import Callable, Dict, List, Union - -import torch -from torch.utils.data.dataloader import DataLoader, default_collate - -from lhotse import validate -from lhotse.cut import CutSet -from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures -from lhotse.utils import compute_num_frames, ifnone -from lhotse.workarounds import Hdf5MemoryIssueFix - - -class K2SpeechRecognitionDataset(torch.utils.data.Dataset): - """ - The PyTorch Dataset for the speech recognition task using k2 library. - - This dataset expects to be queried with lists of cut IDs, - for which it loads features and automatically collates/batches them. - - To use it with a PyTorch DataLoader, set ``batch_size=None`` - and provide a :class:`SimpleCutSampler` sampler. - - Each item in this dataset is a dict of: - - .. code-block:: - - { - 'inputs': float tensor with shape determined by :attr:`input_strategy`: - - single-channel: - - features: (B, T, F) - - audio: (B, T) - - multi-channel: currently not supported - 'supervisions': [ - { - 'sequence_idx': Tensor[int] of shape (S,) - 'text': List[str] of len S - - # For feature input strategies - 'start_frame': Tensor[int] of shape (S,) - 'num_frames': Tensor[int] of shape (S,) - - # For audio input strategies - 'start_sample': Tensor[int] of shape (S,) - 'num_samples': Tensor[int] of shape (S,) - - # Optionally, when return_cuts=True - 'cut': List[AnyCut] of len S - } - ] - } - - Dimension symbols legend: - * ``B`` - batch size (number of Cuts) - * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) - * ``T`` - number of frames of the longest Cut - * ``F`` - number of features - - The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. - """ - - def __init__( - self, - return_cuts: bool = False, - cut_transforms: List[Callable[[CutSet], CutSet]] = None, - input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, - input_strategy: BatchIO = PrecomputedFeatures(), - ): - """ - k2 ASR IterableDataset constructor. - - :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut - objects used to create that batch. - :param cut_transforms: A list of transforms to be applied on each sampled batch, - before converting cuts to an input representation (audio/features). - Examples: cut concatenation, noise cuts mixing, etc. - :param input_transforms: A list of transforms to be applied on each sampled batch, - after the cuts are converted to audio/features. - Examples: normalization, SpecAugment, etc. - :param input_strategy: Converts cuts into a collated batch of audio/features. - By default, reads pre-computed features from disk. - """ - super().__init__() - # Initialize the fields - self.return_cuts = return_cuts - self.cut_transforms = ifnone(cut_transforms, []) - self.input_transforms = ifnone(input_transforms, []) - self.input_strategy = input_strategy - - # This attribute is a workaround to constantly growing HDF5 memory - # throughout the epoch. It regularly closes open file handles to - # reset the internal HDF5 caches. - self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) - - def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: - """ - Return a new batch, with the batch size automatically determined using the constraints - of max_duration and max_cuts. - """ - validate_for_asr(cuts) - - self.hdf5_fix.update() - - # Sort the cuts by duration so that the first one determines the batch time dimensions. - cuts = cuts.sort_by_duration(ascending=False) - - if self.cut_transforms: - orig_cuts = cuts - - cuts = cuts.repeat(times=2) - - for tnfm in self.cut_transforms: - cuts = tnfm(cuts) - - cuts = orig_cuts + cuts - num_copies = 3 - else: - num_copies = 1 - - - # Get a tensor with batched feature matrices, shape (B, T, F) - # Collation performs auto-padding, if necessary. - input_tpl = self.input_strategy(cuts) - if len(input_tpl) == 3: - # An input strategy with fault tolerant audio reading mode. - # "cuts" may be a subset of the original "cuts" variable, - # that only has cuts for which we successfully read the audio. - inputs, _, cuts = input_tpl - else: - inputs, _ = input_tpl - - # Get a dict of tensors that encode the positional information about supervisions - # in the batch of feature matrices. The tensors are named "sequence_idx", - # "start_frame/sample" and "num_frames/samples". - supervision_intervals = self.input_strategy.supervision_intervals(cuts) - - # Apply all available transforms on the inputs, i.e. either audio or features. - # This could be feature extraction, global MVN, SpecAugment, etc. - segments = torch.stack(list(supervision_intervals.values()), dim=1) - for tnfm in self.input_transforms: - inputs = tnfm(inputs, supervision_segments=segments) - - batch = { - "inputs": inputs, - "num_copies": num_copies, - "supervisions": default_collate( - [ - { - "text": supervision.text, - } - for sequence_idx, cut in enumerate(cuts) - for supervision in cut.supervisions - ] - ), - } - # Update the 'supervisions' field with sequence_idx and start/num frames/samples - batch["supervisions"].update(supervision_intervals) - if self.return_cuts: - batch["supervisions"]["cut"] = [ - cut for cut in cuts for sup in cut.supervisions - ] - - has_word_alignments = all( - s.alignment is not None and "word" in s.alignment - for c in cuts - for s in c.supervisions - ) - if has_word_alignments: - # TODO: might need to refactor BatchIO API to move the following conditional logic - # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), - # that returns either num_frames or num_samples depending on the strategy). - words, starts, ends = [], [], [] - frame_shift = cuts[0].frame_shift - sampling_rate = cuts[0].sampling_rate - if frame_shift is None: - try: - frame_shift = self.input_strategy.extractor.frame_shift - except AttributeError: - raise ValueError( - "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " - ) - for c in cuts: - for s in c.supervisions: - words.append([aliword.symbol for aliword in s.alignment["word"]]) - starts.append( - [ - compute_num_frames( - aliword.start, - frame_shift=frame_shift, - sampling_rate=sampling_rate, - ) - for aliword in s.alignment["word"] - ] - ) - ends.append( - [ - compute_num_frames( - aliword.end, - frame_shift=frame_shift, - sampling_rate=sampling_rate, - ) - for aliword in s.alignment["word"] - ] - ) - batch["supervisions"]["word"] = words - batch["supervisions"]["word_start"] = starts - batch["supervisions"]["word_end"] = ends - - return batch - - -def validate_for_asr(cuts: CutSet) -> None: - validate(cuts) - tol = 2e-3 # 1ms - for cut in cuts: - for supervision in cut.supervisions: - assert supervision.start >= -tol, ( - f"Supervisions starting before the cut are not supported for ASR" - f" (sup id: {supervision.id}, cut id: {cut.id})" - ) - - # Supervision start time is relative to Cut ... - # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html - # - # 'supervision.end' is end of supervision inside the Cut - assert supervision.end <= cut.duration + tol, ( - f"Supervisions ending after the cut " - f"are not supported for ASR" - f" (sup id: {supervision.id}, cut id: {cut.id})" - ) diff --git a/egs/librispeech/ASR/zapformer_denoise/subsampling.py b/egs/librispeech/ASR/zapformer_denoise/subsampling.py deleted file mode 100644 index 03e0319feb..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/subsampling.py +++ /dev/null @@ -1,297 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import Tuple, Optional - -import torch -from scaling import ( - ScaleLimiter, - ScaledLinear, - ExpNorm, - Dropout3, - FloatLike, - ScaledConv2d, - ScaleGrad, - ScheduledFloat, - SwashL, - SwashR, - Whiten, -) -from torch import Tensor, nn - - -class ConvNeXt(nn.Module): - """ - Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf - """ - - def __init__( - self, - channels: int, - hidden_ratio: int = 3, - kernel_size: Tuple[int, int] = (7, 7), - ): - super().__init__() - self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - hidden_channels = channels * hidden_ratio - - self.depthwise_conv = nn.Conv2d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=self.padding, - ) - - self.pointwise_conv1 = nn.Conv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1, - ) - - self.activation = SwashL() - - self.pointwise_conv2 = nn.Conv2d( - in_channels=hidden_channels, - out_channels=channels, - kernel_size=1, - ) - - - def forward( - self, x: Tensor, - ) -> Tensor: - """ - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - - The returned value has the same shape as x. - """ - bypass = x - x = self.depthwise_conv(x) - x = self.pointwise_conv1(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - x = bypass + x - - return x - - def streaming_forward( - self, - x: Tensor, - cached_left_pad: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) - - Returns: - - The returned value has the same shape as x. - - Updated cached_left_pad. - """ - padding = self.padding - - # The length without right padding for depth-wise conv - T = x.size(2) - padding[0] - - bypass = x[:, :, :T, :] - - # Pad left side - assert cached_left_pad.size(2) == padding[0], ( - cached_left_pad.size(2), - padding[0], - ) - x = torch.cat([cached_left_pad, x], dim=2) - # Update cached left padding - cached_left_pad = x[:, :, T : padding[0] + T, :] - - # depthwise_conv - x = torch.nn.functional.conv2d( - x, - weight=self.depthwise_conv.weight, - bias=self.depthwise_conv.bias, - padding=(0, padding[1]), - groups=self.depthwise_conv.groups, - ) - x = self.pointwise_conv1(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - x = bypass + x - return x, cached_left_pad - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/2 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: FloatLike = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-3)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - bottleneck: - bottleneck dimension for 1d squeeze-excite - """ - assert in_channels >= 7 - self.in_channels = in_channels - super().__init__() - - # The ScaleGrad module is there to prevent the gradients - # w.r.t. the weight or bias of the first Conv2d module in self.conv from - # exceeding the range of fp16 when using automatic mixed precision (amp) - # training. (The second one is necessary to stop its bias from getting - # a too-large gradient). - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ScaleGrad(0.2), - SwashR(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - SwashR(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=2, - padding=0, - ), - SwashR(), - ) - - - # just one convnext layer - self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) - - # (in_channels-3)//4 - self.out_width = (in_channels-3) // 4 - self.layer3_channels = layer3_channels - - # scale it up a bit, else the output is quite small. - self.out = ScaledLinear(self.out_width * layer3_channels, out_channels, - initial_scale=4.0) - - # use a larger than normal grad_scale on this whitening module; there is - # only one such module, so there is not a concern about adding together - # many copies of this extra gradient term. - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), - prob=(0.025, 0.25), - grad_scale=0.02, - ) - - # max_log_eps=0.0 is to prevent both eps and the output of self.out from - # getting large, there is an unnecessary degree of freedom. - self.out_norm = ExpNorm(out_channels) - self.dropout = Dropout3(dropout, shared_dim=1) - - def pad(self, x: torch.Tensor) -> Tensor: - (N, T, idim) = x.shape - - - right_pad = (4 * ((T + 3) // 4)) - T - # first, pad to be a multiple of 4 frames. this is so we can later reconstruct at - # least the original number of frames. - - # next, we have to add 5 frames in order to get, finally (T + right_pad) // 4 frames. - left_pad = 3 - right_pad = 2 + right_pad - return torch.nn.functional.pad(x, (0, 0, left_pad, right_pad)) - - - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, (T-3)//4, odim) - - output lengths, of shape (batch_size,) - """ - # On entry, x is (batch_size, time, ideim) - x = self.pad(x) - # define x shape now as (N, T, idim) with T being the padded shape. - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - x = self.convnext(x) - - # Now x is of shape (N, odim, (T-5)//4, (idim-3)//4) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, (T-5)//4, out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, (T-5)//4, odim) - x = self.out_whiten(x) - x = self.out_norm(x) - x = self.dropout(x) - - # the "+ 3" reflects the rounding-up-to-a-multiple-of-4 that we do at - # the start of self.pad(). We would, without self.pad() need to have a - # "-5" here and the adding 5 frames in self.pad() cancels that out. - if torch.jit.is_scripting() or torch.jit.is_tracing(): - x_lens = (x_lens + 3) // 4 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_lens = (x_lens + 3) // 4 - assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) - - return x, x_lens diff --git a/egs/librispeech/ASR/zapformer_denoise/test_scaling.py b/egs/librispeech/ASR/zapformer_denoise/test_scaling.py deleted file mode 120000 index b776da79a1..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/test_scaling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer_denoise/train.py b/egs/librispeech/ASR/zapformer_denoise/train.py deleted file mode 100755 index dba253163d..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/train.py +++ /dev/null @@ -1,1378 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --full-libri 1 \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --full-libri 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default) - - ctc loss - - attention decoder loss - - cr-ctc loss (should use half the max-duration compared to regular ctc) -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import DenoisingAsrModel -from optim import Sched3, TransformedAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zapformer import Zapformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def get_adjusted_lr_batches(params: AttributeDict) -> float: - # returns an adjusted form of the "lr_batches" parameter used to set the learning - # rate in the Sched3 scheduler. - # We want the final LR to be based on the geometric mean of "how much data we - # have seen" and "how many batches we have seen". - # an easier way to look at it is this: the formula for learning rate depends - # on (cur_batch / lr_batches). if we write this as: - # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches - # then the numerator is a geometric mean of "how many batches we have seen" - # and "how much data we have seen". We can achieve this by setting - # lr_batches = params.lr_batches * (duration_ratio ** -0.5). - duration_ratio = (params.max_duration * params.world_size) / params.ref_duration - lr_batches = params.lr_batches * (duration_ratio ** -0.5) - logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") - return lr_batches - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def lookup(params: AttributeDict, name: str): - """ - Interprets numerical arguments in `params` by taking into account base-dim; - also parses comma-separated lists of integers, turning them into tuples. - If a particular attribute ending in "dim" is not present we look up - the same name but ending in "factor", and multiply the elements by base_dim. - """ - try: - attr = getattr(params, name) - try: - attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints - if len(attr) == 1: - attr = attr[0] - except: - pass # leave attr as it is, e.g. a string. - return attr - except AttributeError as e: - if name[-3:] != "dim": - raise e - try: - attr = getattr(params, name[:-3] + "multiple") - if isinstance(attr, str): - attr = tuple(map(int, attr.split(","))) # tuple of ints - base_dim = params.base_dim - attr = tuple([i * base_dim for i in attr]) - if len(attr) == 1: - attr = attr[0] - else: # assume int. - assert isinstance(attr, (int, float)), (name, attr) - attr = attr * params.base_dim - return attr - except AttributeError as e: - raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") - - - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="8,8,8", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,1,1", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--base-dim", - type=int, - default=64, - help="Dimension that, via multiples, defines the dimensions of the model." - ) - - parser.add_argument( - "--embed-multiple", - type=int, - default=6, - help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", - ) - - parser.add_argument( - "--text-embed-dim", - type=int, - default=8, - help="Dim of text embeddings.", - ) - - parser.add_argument( - "--speech-loss-scale", - type=float, - default=1.0, - help="Loss scale on the speech part of the loss", - ) - - parser.add_argument( - "--time-embed-multiple", - type=int, - default=4, - help="Multiply by base-dim to determine dimension of time embedding." - ) - - - parser.add_argument( - "--feedforward-multiple", - type=str, - default="3", - help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,8,4", - help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-multiple", - type=str, - default="6,6,6", - help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--debug-interval", - type=int, - default=10, - help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for - finding parts of the network that are diverging or not well trained. - """ - ) - - parser.add_argument( - "--dump-debug-interval", - type=int, - default=0, - help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they - are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. - Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file - opened, seeked-in and closed for each scalar that is written). - """ - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zapformer_denoise/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=17500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--use-bf16", - type=str2bool, - default=False, - help="Whether to use bf16 in AMP.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - warm_step: The warmup period that dictates the decay of the - scale on pruned loss (for transducer) and the reconstruction and prediction - losses. Expressed in terms of the "adjusted batch count", i.e. the - normalized batch count after adjusting for changes in batch size. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - # parameters for attention-decoder - "ignore_id": -1, - "label_smoothing": 0.1, - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_speech_embed(params: AttributeDict) -> nn.Module: - # speech_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - speech_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=lookup(params, "embed_dim"), - dropout=0.0, - ) - return speech_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zapformer( - input_dim=lookup(params, "embed_dim"), - time_embed_dim=lookup(params, "time_embed_dim"), - downsampling_factor=lookup(params, "downsampling_factor"), - num_encoder_layers=lookup(params, "num_encoder_layers"), - encoder_dim=lookup(params, "encoder_dim"), - query_head_dim=lookup(params, "query_head_dim"), - pos_head_dim=lookup(params, "pos_head_dim"), - value_head_dim=lookup(params, "value_head_dim"), - pos_dim=params.pos_dim, - num_heads=lookup(params, "num_heads"), - feedforward_multiple=lookup(params, "feedforward_multiple"), - cnn_module_kernel=lookup(params, "cnn_module_kernel"), - dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=lookup(params, "decoder_dim"), - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - - - -def get_model(params: AttributeDict) -> nn.Module: - - #speech_embed = get_speech_embed(params) - encoder = get_encoder_model(params) - - - model = DenoisingAsrModel( - #speech_embed=speech_embed, - encoder=encoder, - encoder_dim=lookup(params, "embed_dim"), # see embed-multiple - text_embed_dim=lookup(params, "text_embed_dim"), - vocab_size=params.vocab_size, - time_embed_dim=lookup(params, "time_embed_dim") # see time-embed-multiple - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - x = batch["inputs"] - # at entry, feature is (N, T, C) - assert x.ndim == 3 - x = x.to(device) - - supervisions = batch["supervisions"] - x_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) # list of lists. - y_lens = [ len(sent) for sent in y ] - max_y_len = max(y_lens) - y = [ sent + [ 0 ] * (max_y_len - len(sent)) for sent in y ] - y = torch.tensor(y).to(device) - y_lens = torch.tensor(y_lens).to(device) - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - nframes = (x_lens // params.subsampling_factor).sum().item() - info["frames"] = nframes - - with torch.set_grad_enabled(is_training): - speech_loss, text_loss = model(x, x_lens, y, y_lens) - # (speech_loss - 2 * nframes).relu() is to prevent it from completely ignoring the speech loss. - loss = params.speech_loss_scale * speech_loss + (speech_loss - (2.0 * nframes)).relu() + text_loss - - - assert loss.requires_grad == is_training - - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["text_loss"] = text_loss.detach().cpu().item() - info["speech_loss"] = speech_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - if params.debug_interval > 0: - optimizer.write_debug_info(summary_writer=tb_writer) - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.amp.autocast('cuda', - enabled=params.use_autocast, dtype=params.dtype - ): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: - logging.info(f"Caught exception: {e}.") - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if params.use_autocast: - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - if not params.inf_check: - register_inf_check_hooks(model) - logging.warning(f"Grad scale is small: {cur_grad_scale}") - - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or - batch_idx % 100 == 0 and cur_grad_scale < 8.0 or - batch_idx % 400 == 0 and cur_grad_scale < 32.0): - scaler.update(cur_grad_scale * 2.0) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_autocast: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: - optimizer.write_debug_info(summary_writer=tb_writer) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - params.vocab_size = sp.get_piece_size() - - if params.use_bf16: # amp + bf16 - assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" - assert not params.use_fp16, "You can only use either fp16 or bf16" - params.dtype = torch.bfloat16 - params.use_autocast = True - elif params.use_fp16: # amp + fp16 - params.dtype = torch.float16 - params.use_autocast = True - else: # fp32 - params.dtype = torch.float32 - params.use_autocast = False - - logging.info(f"Using dtype={params.dtype}") - logging.info(f"Use AMP={params.use_autocast}") - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = TransformedAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - debug_interval=params.debug_interval, - ) - - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - librispeech = LibriSpeechAsrDataModule(args) - - if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() - - # previously we used the following code to load all training cuts, - # strictly speaking, shuffled training cuts should be used instead, - # but we leave the code here to demonstrate that there is an option - # like this to combine multiple cutsets - - # train_cuts = librispeech.train_clean_100_cuts() - # train_cuts += librispeech.train_clean_360_cuts() - # train_cuts += librispeech.train_other_500_cuts() - else: - train_cuts = librispeech.train_clean_100_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librispeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics and False: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - d = diagnostic.print_diagnostics() - filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" - torch.save(d, filename) - logging.info(f"Saved detailed diagnostics to {filename}") - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, - spec_augment: Optional[nn.Module] = None, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.amp.autocast('cuda', - enabled=params.use_autocast, dtype=params.dtype - ): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - spec_augment=spec_augment, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/zapformer_denoise/zapformer.py b/egs/librispeech/ASR/zapformer_denoise/zapformer.py deleted file mode 100644 index e9839ff451..0000000000 --- a/egs/librispeech/ASR/zapformer_denoise/zapformer.py +++ /dev/null @@ -1,1344 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from scaling import ( - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - OrthogonalLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - ScaleLimiter, - ActivationDropoutAndLinear, - ExpNorm, - ChunkCausalDepthwiseConv1d, - Dropout2, - FloatLike, - ScheduledFloat, - Whiten, - convert_num_channels, - limit_param_value, - penalize_abs_values_gt, - softmax, -) -from torch import Tensor, nn - - -class Zapformer(nn.Module): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - time_embed_dim: an integer giving the dimension of the time embeddings provided - to the network. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - - dropout (float): dropout rate - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. - """ - def __init__( - self, - input_dim: int, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - time_embed_dim: int = 256, - num_encoder_layers: Union[int, Tuple[int]] = 4, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_multiple: Union[int, Tuple[int]] = 4, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - ) -> None: - super().__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - - def _to_tuple(x): - """Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_multiple = _to_tuple(feedforward_multiple) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames - - # each one will be ZapformerEncoder or OrthogonalDownsample or OrthogonalUpsample - encoders = [] - - num_encoders = len(downsampling_factor) - cur_downsample = 1 - - # caution: some changes we made for this break the streaming, later we'll try to fix this. - encoders_downsampling_factors = [ ] - - # make it so large the limit is never reached. - max_proj_dim = max(downsampling_factor) * max(encoder_dim) - - def set_downsample_factor(cur_downsample, ds): - while cur_downsample < ds: - # need to downsample - encoders.append(OrthogonalDownsample(channels=input_dim * cur_downsample, - proj_dim=min(2 * input_dim * cur_downsample, max_proj_dim))) - cur_downsample *= 2 - while cur_downsample > ds: - encoders.append(OrthogonalUpsample(channels=input_dim * cur_downsample, - proj_dim=min(input_dim * cur_downsample, max_proj_dim))) - cur_downsample //= 2 - return cur_downsample - - for i in range(num_encoders): - cur_downsample = set_downsample_factor(cur_downsample, downsampling_factor[i]) - - encoder_layer = ZapformerEncoderLayer( - embed_dim=encoder_dim[i], - pos_dim=pos_dim, - num_heads=num_heads[i], - query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_multiple=feedforward_multiple[i], - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - causal=causal, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = ZapformerEncoder( - encoder_layer, - num_encoder_layers[i], - dim=cur_downsample*input_dim, - pos_dim=pos_dim, - time_embed_dim=time_embed_dim, - ) - encoder.encoder_index = i - encoders.append(encoder) - - cur_downsample = set_downsample_factor(cur_downsample, 1) - - self.encoders = nn.ModuleList(encoders) - - - def forward( - self, - x: Tensor, - time_embed: Tensor, - x_lens: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, encoder_dim). - time_embed: - The timestep-embedding tensor. Its shape is (batch_size, time_embed_dim) - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return embeddings with the same shape as x: (seq_len, batch_size, encoder_dim) - """ - orig_seq_len = x.shape[0] - - def truncate(x, downsampling_factor): - max_len = (orig_seq_len + downsampling_factor - 1) // downsampling_factor - return x[:max_len] if x.shape[0] > max_len else x - - - for module in self.encoders: - if isinstance(module, ZapformerEncoder): - i = module.encoder_index # was set in this class's __init__ function. - ds = self.downsampling_factor[i] - x = truncate(x, ds) - x = module( - x, - time_embed, - src_key_padding_mask=( - None - if src_key_padding_mask is None - else src_key_padding_mask[..., ::ds] - ), - ) - else: - x = module(x) - - x = x[:orig_seq_len] - return x - - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - -class ZapformerEncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_multiple: determines the hidden dimension of the feedforward module - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module (default=31). - - Examples:: - >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_multiple: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - randomize_scale: FloatLike = ScheduledFloat((0.0, 1.0), (20000.0, 0.75)), - ) -> None: - super(ZapformerEncoderLayer, self).__init__() - self.embed_dim = embed_dim - self.name = None # will be set from training loop - - self.randomize_scale = copy.deepcopy(randomize_scale) - # self.bypass implements layer skipping as well as learnable scale on a residual term; see its default values. - self.residual = ResidualModule( - embed_dim, - ) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1, self.self_attn2 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(2) ] - - feedforward_dim = embed_dim * feedforward_multiple - self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) - - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) - - self.conv_module1, self.conv_module2 = [ ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - for _ in range(2) ] - - self.scale_limiter = ScaleLimiter(max_var=2.0) - - self.norm = ExpNorm(embed_dim) - - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - src = src + self.self_attn1(src, attn_weights) - - src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) - - src = src + self.feed_forward2(src) - - src = src + self.self_attn2(src, attn_weights) - - src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask) - - src = src + self.feed_forward3(src) - - src = self.residual(src_orig, src) - - src = self.scale_limiter(src) - - src = self.norm(src) - - return src - - -class ZapformerEncoder(nn.Module): - r"""ZapformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ZapformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - dim: the dimension of the input and output (layer dim may be less than this). - pos_dim: the dimension for the relative positional encoding -dropout: - - Examples:: - >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = ZapformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - - - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dim: int, - pos_dim: int, - time_embed_dim: int, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.0, length_factor=1.0 - ) - self.name = None - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - self.residual = ResidualModule(encoder_layer.embed_dim) - - self.time_embed = ScaledLinear(time_embed_dim, encoder_layer.embed_dim, initial_scale=0.1) - - #bypass_dim = dim - encoder_layer.embed_dim - self.copy_bypass = Identity() - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(3.0), - prob=(1, 1), - grad_scale=0.025, - ) - - - - def forward( - self, - src: Tensor, - time_embed: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), - but embed_dim is allowed to exceed the modules' embed_dim; we will bypass - any extra dimensions. - time_embed: the time embedding, shape: (batch_size, seq_len) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - - num_channels = src.shape[-1] - layer_dim = self.layers[0].embed_dim - if num_channels > layer_dim: - src, bypass = src[..., :layer_dim], src[..., layer_dim:] - - - src_orig = src - src = src + self.time_embed(time_embed) - for i, mod in enumerate(self.layers): - src = mod( - src, - pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - # randomize_factor can be viewed as a simple version of an - # importance-sampling factor. - - src = self.residual(src_orig, src) - src = self.whiten(src) - - if num_channels > layer_dim: - bypass = self.copy_bypass(bypass) - src = torch.cat((src, bypass), dim=-1) - - return src - - -class ResidualModule(nn.Module): - """ - An nn.Module that implements a learnable residual scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. - """ - - def __init__( - self, - embed_dim: int, - function_scale_min: FloatLike = 0.1, - ): - super().__init__() - self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.function_scale_min = copy.deepcopy(function_scale_min) - - - def _get_scales(self): - function_scale = self.function_scale - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - function_scale = limit_param_value( - function_scale, min=float(self.function_scale_min), max=1.0, - ) - residual_scale = 1.0 - function_scale - return residual_scale, function_scale - - def forward(self, src_orig: Tensor, src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - residual_scale, function_scale = self._get_scales() - return residual_scale * src_orig + function_scale * src - - - -class OrthogonalDownsample(torch.nn.Module): - """ - Does downsampling with an orthogonal matrix, by a factor of two. Projection is initialized - in a special way and enforced to be orthogonal. - - Args: - channels: the number of input channels; the num output channels will be twice this - proj_dim: the number of channels, after combining 2 frames by interpolating their channels - as [ a b a b, .. ] that will actually be projected; the rest are just copied. - proj_dim=2 * channels would mean all channels are projected in a learned way - causal: True for causal systems, only affects error messages as requires even - input num frames. - """ - def __init__( - self, channels: int, proj_dim: int, causal: bool = False, - ): - super().__init__() - assert proj_dim <= channels * 2 - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.75 - self.causal = causal - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - - if seq_len % 2 == 1: - if torch.jit.is_tracing(): - assert ( - not self.causal - ), f"pad should be zero for exporting streaming models. Given {pad}" - src = torch.cat((src, src[-1:]), dim=0) - seq_len += 1 - - # the following will place each 2 frames of a particular channel right after - # each other as if they were two different channels. - src = torch.stack((src[0::2], src[1::2]), dim=-1) - src = src.reshape(seq_len // 2, batch_size, in_channels * 2) - proj_channels = self.proj.weight.shape[0] - if proj_channels < in_channels * 2: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) - return src - -class OrthogonalUpsample(torch.nn.Module): - """ - A very simple form of upsampling with an orthogonal matrix. - - proj_dim: the number of channels that will actually be projected; the rest are just copied. - proj_dim=channels would mean all channels are projected in a learned way - - """ - def __init__(self, channels: int, proj_dim: int): - super().__init__() - assert proj_dim <= channels - # gradually make smaller and then turn off the non-orthognality penalty. - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, - penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.75 - - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*2), batch_size, num_channels // 2) - """ - proj_channels = self.proj.weight.shape[0] - (seq_len, batch_size, in_channels) = src.shape - - if proj_channels < in_channels: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) - - src = torch.stack((src[..., 0::2], src[..., 1::2]), - dim=1) # (seq_len, 2, batch_size, in_channels // 2) - src = src.reshape(seq_len * 2, batch_size, in_channels // 2) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the Fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - - def __init__( - self, - embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0, embed_dim - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0, length_factor - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = self.embed_dim**0.5 - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = ( - compression_length - * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) - ) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.name = None # will be overwritten in training code; for diagnostics. - - self.attn_score_limit = ScheduledFloat((0.0, 5.0), (5000.0, 20.0)) - self.attn_score_penalty_prob = ScheduledFloat((0.0, 1.0), (5000.0, 1.0), (5001.0, 0.1)) - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, - bias=True, initial_scale=0.125 * query_head_dim**-0.25 - ) - - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnostics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim, ( - p.shape[-1], - num_heads, - pos_head_dim, - ) - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(k) # does nothing in the forward pass. [this may not really be needed due to the orthogonality constraint.] - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - if True: - # position scores. - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, seq_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < float(self.attn_score_penalty_prob): - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt( - attn_scores, limit=float(self.attn_score_limit), penalty=1.0e-04, name=self.name - ) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - seq_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.info( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, - bias=True, out_groups=num_heads) - - self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 - ) - - f = max(1.0, embed_dim / (num_heads * value_head_dim)) - # the whitening metric cannot be less than f because of the rank imposed - # by the bottleneck. the final whitening limit will be (2.0*3.0) times f, - # i.e. 6 times greater than the mathematical smallest value it can have. - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(f * 2.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) - - return x - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zapformer model.""" - - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): - super(FeedforwardModule, self).__init__() - # try to get in the useful range of the activation function, i.e. not too small. - self.in_proj = ScaledLinear(embed_dim, feedforward_dim) - # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed - # to the TransformedAdam optimizer. - self.in_proj.weight_min_rms = 0.02 - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( - feedforward_dim, - embed_dim, - activation="SwashL", - dropout_p=dropout, - dropout_shared_dim=0, - bias=True, - initial_scale=0.5, - ) - - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) - - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zapformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - causal: bool, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, - 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = ( - ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) - if causal - else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, - channels, - activation="SwashR", - dropout_p=0.0, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if ( - not torch.jit.is_scripting() - and not torch.jit.is_tracing() - and chunk_size >= 0 - ): - # Not support exporting a model for simulated streaming decoding - assert ( - self.causal - ), "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale - - def forward(self, x): - return x * self.scale - - -def _test_zipformer_main(causal: bool = False): - seq_len = 20 - # Just make sure the forward pass runs. - - input_dim = 50 - time_embed_dim = 64 - - c = Zapformer( - input_dim=input_dim, - encoder_dim=(64, 96), - time_embed_dim=time_embed_dim, - num_heads=(4, 4), - causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,), - ) - - batch_size = 6 # make it even, as PredictLoss requires even batch size. - seq_len = 21 - # Just make sure the forward pass runs. - time_embed = torch.randn(batch_size, time_embed_dim) - - f = c( - torch.randn(seq_len, batch_size, input_dim), - time_embed, - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f.sum().backward() - c.eval() - f = c( - torch.randn(seq_len, batch_size, input_dim), - time_embed, - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main(False) - _test_zipformer_main(True) From 9fb0cad612fdd5ccc7695d0ca7e0941cc68e1dd4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Apr 2026 14:53:08 +0800 Subject: [PATCH 1007/1191] Remove zapformer2 directory --- .../ASR/zapformer2/asr_datamodule.py | 454 ---- .../ASR/zapformer2/attention_decoder.py | 1 - egs/librispeech/ASR/zapformer2/beam_search.py | 1 - egs/librispeech/ASR/zapformer2/ctc_decode.py | 1 - egs/librispeech/ASR/zapformer2/decode.py | 1089 --------- .../ASR/zapformer2/decode_gigaspeech.py | 1 - .../ASR/zapformer2/decode_stream.py | 1 - egs/librispeech/ASR/zapformer2/decoder.py | 1 - .../ASR/zapformer2/encoder_interface.py | 1 - .../ASR/zapformer2/export-onnx-ctc.py | 1 - .../zapformer2/export-onnx-streaming-ctc.py | 1 - .../ASR/zapformer2/export-onnx-streaming.py | 1 - egs/librispeech/ASR/zapformer2/export-onnx.py | 1 - egs/librispeech/ASR/zapformer2/export.py | 1 - egs/librispeech/ASR/zapformer2/finetune.py | 1 - .../ASR/zapformer2/generate_averaged_model.py | 1 - .../ASR/zapformer2/jit_pretrained.py | 1 - .../ASR/zapformer2/jit_pretrained_ctc.py | 1 - .../zapformer2/jit_pretrained_streaming.py | 1 - egs/librispeech/ASR/zapformer2/joiner.py | 1 - .../ASR/zapformer2/label_smoothing.py | 1 - egs/librispeech/ASR/zapformer2/model.py | 630 ----- egs/librispeech/ASR/zapformer2/my_profile.py | 1 - egs/librispeech/ASR/zapformer2/onnx_check.py | 1 - egs/librispeech/ASR/zapformer2/onnx_decode.py | 1 - .../onnx_pretrained-streaming-ctc.py | 1 - .../zapformer2/onnx_pretrained-streaming.py | 1 - .../ASR/zapformer2/onnx_pretrained.py | 1 - .../ASR/zapformer2/onnx_pretrained_ctc.py | 1 - .../ASR/zapformer2/onnx_pretrained_ctc_H.py | 1 - .../ASR/zapformer2/onnx_pretrained_ctc_HL.py | 1 - .../ASR/zapformer2/onnx_pretrained_ctc_HLG.py | 1 - .../onnx_pretrained_ctc_HLG_streaming.py | 1 - egs/librispeech/ASR/zapformer2/optim.py | 1 - egs/librispeech/ASR/zapformer2/pretrained.py | 1 - .../ASR/zapformer2/pretrained_ctc.py | 1 - .../relative_position_attention_bwd_k_2.py | 321 --- .../relative_position_attention_bwd_pos_2.py | 321 --- .../relative_position_attention_bwd_q_2.py | 332 --- .../relative_position_attention_fwd_2.py | 302 --- ...ive_position_attention_module_optimized.py | 118 - egs/librispeech/ASR/zapformer2/scaling.py | 1 - .../ASR/zapformer2/scaling_converter.py | 1 - .../ASR/zapformer2/speech_recognition.py | 229 -- .../ASR/zapformer2/streaming_beam_search.py | 1 - .../ASR/zapformer2/streaming_decode.py | 1 - egs/librispeech/ASR/zapformer2/subsampling.py | 1 - .../ASR/zapformer2/test_scaling.py | 1 - .../ASR/zapformer2/test_subsampling.py | 1 - egs/librispeech/ASR/zapformer2/train.py | 1678 ------------- egs/librispeech/ASR/zapformer2/zipformer.py | 2066 ----------------- 51 files changed, 7580 deletions(-) delete mode 100755 egs/librispeech/ASR/zapformer2/asr_datamodule.py delete mode 120000 egs/librispeech/ASR/zapformer2/attention_decoder.py delete mode 120000 egs/librispeech/ASR/zapformer2/beam_search.py delete mode 120000 egs/librispeech/ASR/zapformer2/ctc_decode.py delete mode 100755 egs/librispeech/ASR/zapformer2/decode.py delete mode 120000 egs/librispeech/ASR/zapformer2/decode_gigaspeech.py delete mode 120000 egs/librispeech/ASR/zapformer2/decode_stream.py delete mode 120000 egs/librispeech/ASR/zapformer2/decoder.py delete mode 120000 egs/librispeech/ASR/zapformer2/encoder_interface.py delete mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-ctc.py delete mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py delete mode 120000 egs/librispeech/ASR/zapformer2/export-onnx-streaming.py delete mode 120000 egs/librispeech/ASR/zapformer2/export-onnx.py delete mode 120000 egs/librispeech/ASR/zapformer2/export.py delete mode 120000 egs/librispeech/ASR/zapformer2/finetune.py delete mode 120000 egs/librispeech/ASR/zapformer2/generate_averaged_model.py delete mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained.py delete mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py delete mode 120000 egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py delete mode 120000 egs/librispeech/ASR/zapformer2/joiner.py delete mode 120000 egs/librispeech/ASR/zapformer2/label_smoothing.py delete mode 100755 egs/librispeech/ASR/zapformer2/model.py delete mode 120000 egs/librispeech/ASR/zapformer2/my_profile.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_check.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_decode.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py delete mode 120000 egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py delete mode 120000 egs/librispeech/ASR/zapformer2/optim.py delete mode 120000 egs/librispeech/ASR/zapformer2/pretrained.py delete mode 120000 egs/librispeech/ASR/zapformer2/pretrained_ctc.py delete mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py delete mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py delete mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py delete mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py delete mode 100755 egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py delete mode 120000 egs/librispeech/ASR/zapformer2/scaling.py delete mode 120000 egs/librispeech/ASR/zapformer2/scaling_converter.py delete mode 100755 egs/librispeech/ASR/zapformer2/speech_recognition.py delete mode 120000 egs/librispeech/ASR/zapformer2/streaming_beam_search.py delete mode 120000 egs/librispeech/ASR/zapformer2/streaming_decode.py delete mode 120000 egs/librispeech/ASR/zapformer2/subsampling.py delete mode 120000 egs/librispeech/ASR/zapformer2/test_scaling.py delete mode 120000 egs/librispeech/ASR/zapformer2/test_subsampling.py delete mode 100755 egs/librispeech/ASR/zapformer2/train.py delete mode 100644 egs/librispeech/ASR/zapformer2/zipformer.py diff --git a/egs/librispeech/ASR/zapformer2/asr_datamodule.py b/egs/librispeech/ASR/zapformer2/asr_datamodule.py deleted file mode 100755 index 4db6e101fb..0000000000 --- a/egs/librispeech/ASR/zapformer2/asr_datamodule.py +++ /dev/null @@ -1,454 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, -) -# This K2SpeechRecognitionDataset is a modified version of one from -# lhotse.dataset, modified to, in training mode, to return a batch that has 3 -# different copies of the same data with the last two having different Musan -# augmentations and the first having none; and also include the key "num_copies" -# in the batch which would be 1 for the validation data (no Musan) and 3 for the -# training data with musan. -from speech_recognition import K2SpeechRecognitionDataset -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="""Used only when --mini-libri is False.When enabled, - use 960h LibriSpeech. Otherwise, use 100h subset.""", - ) - group.add_argument( - "--mini-libri", - type=str2bool, - default=False, - help="True for mini librispeech", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=[], - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_clean_5_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get train-clean-5 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" - ) - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def dev_clean_2_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get dev-clean-2 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" - ) - - @lru_cache() - def gigaspeech_subset_small_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech subset-S cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") - - @lru_cache() - def gigaspeech_dev_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") - - @lru_cache() - def gigaspeech_test_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zapformer2/attention_decoder.py b/egs/librispeech/ASR/zapformer2/attention_decoder.py deleted file mode 120000 index 830180a0cd..0000000000 --- a/egs/librispeech/ASR/zapformer2/attention_decoder.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/beam_search.py b/egs/librispeech/ASR/zapformer2/beam_search.py deleted file mode 120000 index 8554e44ccf..0000000000 --- a/egs/librispeech/ASR/zapformer2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/ctc_decode.py b/egs/librispeech/ASR/zapformer2/ctc_decode.py deleted file mode 120000 index a78e5c1df0..0000000000 --- a/egs/librispeech/ASR/zapformer2/ctc_decode.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode.py b/egs/librispeech/ASR/zapformer2/decode.py deleted file mode 100755 index 221f01297b..0000000000 --- a/egs/librispeech/ASR/zapformer2/decode.py +++ /dev/null @@ -1,1089 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from lhotse import set_caching_enabled -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--skip-scoring", - type=str2bool, - default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2] - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) - prefix = f"{params.decoding_method}" - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - prefix += f"_beam-{params.beam}" - prefix += f"_max-contexts-{params.max_contexts}" - prefix += f"_max-states-{params.max_states}" - if "nbest" in params.decoding_method: - prefix += f"_num-paths-{params.num_paths}" - prefix += f"_nbest-scale-{params.nbest_scale}" - if "LG" in params.decoding_method: - prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - - return {prefix: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix += f"_beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"_context-score-{params.context_score}" - return {prefix: hyps} - else: - prefix += f"_beam-size-{params.beam_size}" - return {prefix: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_asr_output( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - """ - Save text produced by ASR. - """ - for key, results in results_dict.items(): - - recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - - results = sorted(results) - store_transcripts(filename=recogs_filename, texts=results) - - logging.info(f"The transcripts are stored in {recogs_filename}") - - -def save_wer_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], -): - """ - Save WER and per-utterance word alignments. - """ - test_set_wers = dict() - for key, results in results_dict.items(): - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w", encoding="utf8") as fd: - wer = write_error_stats( - fd, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info(f"Wrote detailed error stats to {errs_filename}") - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - - wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - - with open(wer_filename, "w", encoding="utf8") as fd: - print("settings\tWER", file=fd) - for key, val in test_set_wers: - print(f"{key}\t{val}", file=fd) - - s = f"\nFor {test_set_name}, WER of different settings are:\n" - note = f"\tbest for {test_set_name}" - for key, val in test_set_wers: - s += f"{key}\t{val}{note}\n" - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - # enable AudioCache - set_caching_enabled(True) # lhotse - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}_avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"_chunk-{params.chunk_size}" - params.suffix += f"_left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"_beam-{params.beam}" - params.suffix += f"_max-contexts-{params.max_contexts}" - params.suffix += f"_max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"_nbest-scale-{params.nbest_scale}" - params.suffix += f"_num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"_context-{params.context_size}" - params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "_use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append((sp.encode(line.strip()), 0.0)) - context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - dev_clean_cuts = librispeech.dev_clean_cuts() - dev_other_cuts = librispeech.dev_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) - dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) - - test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] - test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_asr_output( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - if not params.skip_scoring: - save_wer_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py b/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py deleted file mode 120000 index 63b0ef617b..0000000000 --- a/egs/librispeech/ASR/zapformer2/decode_gigaspeech.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decode_gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decode_stream.py b/egs/librispeech/ASR/zapformer2/decode_stream.py deleted file mode 120000 index 4e59d04a12..0000000000 --- a/egs/librispeech/ASR/zapformer2/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/decoder.py b/egs/librispeech/ASR/zapformer2/decoder.py deleted file mode 120000 index cab465d2b9..0000000000 --- a/egs/librispeech/ASR/zapformer2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/encoder_interface.py b/egs/librispeech/ASR/zapformer2/encoder_interface.py deleted file mode 120000 index aa5d0217a8..0000000000 --- a/egs/librispeech/ASR/zapformer2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py deleted file mode 120000 index dc14e93e75..0000000000 --- a/egs/librispeech/ASR/zapformer2/export-onnx-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py deleted file mode 120000 index 3baa2b673c..0000000000 --- a/egs/librispeech/ASR/zapformer2/export-onnx-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py deleted file mode 120000 index d18cb9a9a1..0000000000 --- a/egs/librispeech/ASR/zapformer2/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export-onnx.py b/egs/librispeech/ASR/zapformer2/export-onnx.py deleted file mode 120000 index f343cf7027..0000000000 --- a/egs/librispeech/ASR/zapformer2/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/export.py b/egs/librispeech/ASR/zapformer2/export.py deleted file mode 120000 index 1a126ab695..0000000000 --- a/egs/librispeech/ASR/zapformer2/export.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/finetune.py b/egs/librispeech/ASR/zapformer2/finetune.py deleted file mode 120000 index 0e9e7989b9..0000000000 --- a/egs/librispeech/ASR/zapformer2/finetune.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/finetune.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/generate_averaged_model.py b/egs/librispeech/ASR/zapformer2/generate_averaged_model.py deleted file mode 120000 index b65513a058..0000000000 --- a/egs/librispeech/ASR/zapformer2/generate_averaged_model.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained.py b/egs/librispeech/ASR/zapformer2/jit_pretrained.py deleted file mode 120000 index 5d45825206..0000000000 --- a/egs/librispeech/ASR/zapformer2/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py deleted file mode 120000 index 43aeb684bf..0000000000 --- a/egs/librispeech/ASR/zapformer2/jit_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py deleted file mode 120000 index 8e5e6f9812..0000000000 --- a/egs/librispeech/ASR/zapformer2/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/joiner.py b/egs/librispeech/ASR/zapformer2/joiner.py deleted file mode 120000 index 444cb5f150..0000000000 --- a/egs/librispeech/ASR/zapformer2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/label_smoothing.py b/egs/librispeech/ASR/zapformer2/label_smoothing.py deleted file mode 120000 index 3690afff9d..0000000000 --- a/egs/librispeech/ASR/zapformer2/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/model.py b/egs/librispeech/ASR/zapformer2/model.py deleted file mode 100755 index 278e498032..0000000000 --- a/egs/librispeech/ASR/zapformer2/model.py +++ /dev/null @@ -1,630 +0,0 @@ -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from torch import Tensor -from encoder_interface import EncoderInterface -from scaling import ScaledLinear, convert_num_channels, PredictLoss -from icefall.utils import add_sos, make_pad_mask, time_warp - - -class AsrModel(nn.Module): - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - decoder: Optional[nn.Module] = None, - joiner: Optional[nn.Module] = None, - attention_decoder: Optional[nn.Module] = None, - encoder_dim: int = 384, - decoder_dim: int = 512, - vocab_size: int = 500, - use_transducer: bool = True, - use_ctc: bool = False, - use_attention_decoder: bool = False, - ): - """A joint CTC & Transducer ASR model. - - - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) - - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) - - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) - - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dim) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - It is used when use_transducer is True. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - It is used when use_transducer is True. - use_transducer: - Whether use transducer head. Default: True. - use_ctc: - Whether use CTC head. Default: False. - use_attention_decoder: - Whether use attention-decoder head. Default: False. - """ - super().__init__() - - assert ( - use_transducer or use_ctc - ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" - - assert isinstance(encoder, EncoderInterface), type(encoder) - - self.encoder_embed = encoder_embed - self.encoder = encoder - - self.predict_loss = PredictLoss(encoder_dim) - - self.use_transducer = use_transducer - if use_transducer: - # Modules for Transducer head - assert decoder is not None - assert hasattr(decoder, "blank_id") - assert joiner is not None - - - - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_scale=0.1, - ) - self.simple_lm_proj = ScaledLinear( - decoder_dim, vocab_size, initial_scale=0.1, - ) - - else: - assert decoder is None - assert joiner is None - - self.use_ctc = use_ctc - if use_ctc: - # Modules for CTC head - self.ctc_output = nn.Sequential( - nn.Dropout(p=0.1), - ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1), - nn.LogSoftmax(dim=-1), - ) - - self.use_attention_decoder = use_attention_decoder - if use_attention_decoder: - self.attention_decoder = attention_decoder - else: - assert attention_decoder is None - - self.reconstruction_proj = ScaledLinear( - encoder_dim, 4 * encoder_embed.in_channels, initial_scale=0.1) - - - def forward_encoder( - self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0, sd_prob: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute encoder outputs. - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - aux_loss_scale: - auxiliary-loss scale, for scaling cosine losses in the encoders. - sc_prob: - stochastic-depth probability: not a layer skipping probabilty but the probabibilty - of taking the output of a randomly chosen layer, instead of the last layer. - - - Returns: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - """ - # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") - specaug_mask = (x[..., 0] == x[..., 1]) # (N, T) - - x, x_lens = self.encoder_embed(x, x_lens, aux_loss_scale=aux_loss_scale) - # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - - - src_key_padding_mask = make_pad_mask(x_lens) # (N, T) - specaug_mask = specaug_mask[:, ::2] - assert abs(specaug_mask.shape[1] - src_key_padding_mask.shape[1]) < 10 - specaug_mask = convert_num_channels(specaug_mask, src_key_padding_mask.shape[1]) # pad or truncate. (N, T) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask, - aux_loss_scale=aux_loss_scale, - sd_prob=0.0) - - predict_loss = self.compute_predict_loss(encoder_out, src_key_padding_mask[:, ::2], specaug_mask[:, ::2]) - - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - - - return encoder_out, encoder_out_lens, predict_loss - - - def compute_predict_loss(self, - encoder_out: Tensor, - src_key_padding_mask: Optional[Tensor], - specaug_mask: Optional[Tensor]) -> Tensor: - if src_key_padding_mask is not None and specaug_mask is not None: - mask = torch.logical_and(src_key_padding_mask.t().logical_not(), specaug_mask.t().logical_not()) - elif src_key_padding_mask is not None: - mask = src_key_padding_mask.t().logical_not() - elif specaug_mask is not None: - mask = specaug_mask.t().logical_not() - else: - mask = None - return self.predict_loss(encoder_out, mask) - - - def forward_ctc( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - targets: torch.Tensor, - target_lengths: torch.Tensor, - ) -> torch.Tensor: - """Compute CTC loss. - Args: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - targets: - Target Tensor of shape (sum(target_lengths)). The targets are assumed - to be un-padded and concatenated within 1 dimension. - """ - # Compute CTC log-prob - ctc_output = self.ctc_output(encoder_out) # (N, T, C) - - - # the calls to .long() were added as a workaround for a problem with - # torch.nn.functional.ctc_loss() on newer torch versions. Previously - # instead of .long() we had .cpu(). This activates the use of CUDNN - # because it only uses CUDNN if integer inputs are in int32 and on CPU. - # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) - # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: - # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. - # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss - # seems to have a bug with "int32" integer arguments (it returns infinity), so we call - # .long() to use the torch implementation and avoid that bug. - ctc_loss = torch.nn.functional.ctc_loss( - log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets.long(), - input_lengths=encoder_out_lens.long(), - target_lengths=target_lengths.long(), - reduction="sum", - ) - return ctc_loss - - def forward_cr_ctc( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - targets: torch.Tensor, - target_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute CTC loss, with consistency regularization loss if we are in training mode. - Args: - encoder_out: - Encoder output, of shape (2 * N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (2 * N,). - targets: - Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed - to be un-padded and concatenated within 1 dimension. - """ - # Compute CTC loss - ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) - ctc_loss = torch.nn.functional.ctc_loss( - log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) - targets=targets.long(), # the calls to .long() were added due to a bug in torch 2.5.1cuda12.1 on A20. - input_lengths=encoder_out_lens.long(), - target_lengths=target_lengths.long(), - reduction="sum", - ) - - # Compute consistency regularization loss - exchanged_targets = ctc_output.detach().chunk(2, dim=0) - exchanged_targets = torch.cat( - [exchanged_targets[1], exchanged_targets[0]], dim=0 - ) # exchange: [x1, x2] -> [x2, x1] - cr_loss = nn.functional.kl_div( - input=ctc_output, - target=exchanged_targets, - reduction="none", - log_target=True, - ) # (2 * N, T, C) - length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) - cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() - - return ctc_loss, cr_loss - - def forward_transducer( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - y: k2.RaggedTensor, - y_lens: torch.Tensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute Transducer loss. - Args: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - """ - # Now for the decoder, i.e., the prediction network - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = encoder_out_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - - with torch.amp.autocast('cuda', enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.amp.autocast('cuda', enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) - - return simple_loss, pruned_loss - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - spec_augment: Optional[nn.Module] = None, - supervision_segments: Optional[torch.Tensor] = None, - time_warp_factor: Optional[int] = 80, - num_copies: int = 1, - aux_loss_scale: float = 0.0, - sd_prob: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - spec_augment: - The SpecAugment instance, or similar/compatible object, that masks - log-mel features. - supervision_segments: - An int tensor of shape ``(S, 3)``. ``S`` is the number of - supervision segments that exist in ``features``. Used only for - time-warping, if num_copies > 1. - time_warp_factor: - Parameter for the time warping; larger values mean more warping. - Set to ``None``, or less than ``1``, to disable. - Used only if num_copies > 1, corresponds to training mode. - num_copies: - the number of copies of the same data that are in the batch, e.g. 1, 2 - or 3; affects CRCTC, spec-augment, etc. - aux_loss_scale: - auxiliary-loss scale, for scaling cosine losses in the encoders. - sc_prob: - stochastic-depth probability: not a layer skipping probabilty but the probabibilty - of taking the output of a randomly chosen layer, instead of the last layer. - - Returns: - Return the transducer losses, CTC loss, AED loss, - and consistency-regularization loss in form of - (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) - - device = x.device - - if num_copies > 1: - assert num_copies == 3 # for now. - # will do SpecAugment or similar. - assert spec_augment is not None and getattr(spec_augment, 'time_warp_factor', -1) < 0 - - (batch_size, seq_len, num_channels) = x.shape - B = batch_size // num_copies - x = x.reshape(num_copies, B, seq_len, num_channels) - - do_time_warp = True - if do_time_warp: - # Apply time warping. First append the copies on the channel - # dimension so all copies get the exact same time-warping. - x = x.permute(1, 2, 0, 3).reshape(B, seq_len, num_copies * num_channels) - - assert supervision_segments is not None - with torch.amp.autocast('cuda', enabled=False): - x = time_warp( - x.to(torch.float), - time_warp_factor=time_warp_factor, - supervision_segments=supervision_segments[:B], - ) - x = x.reshape(B, seq_len, num_copies, num_channels) - x = x.permute(2, 0, 1, 3) # x: (num_copies, B, seq_len, num_channels) - - # x_no_specaug is several repeats of the 1st copy of the data, which - # is the one not augmented with Musan. But it does have time - # warping and mel warping. - x_no_specaug = x[0:1].repeat(num_copies - 1, 1, 1, 1).reshape( - B * (num_copies - 1), seq_len, num_channels) - - - # Independently apply frequency masking and time masking to all but the first - # copy of the data. - x = spec_augment(x[1:].reshape(-1, seq_len, num_channels)) - - x_lens = x_lens[:B*(num_copies-1)] - y = y[:B*(num_copies-1)] - else: - x_no_specaug = x - - - # Compute encoder outputs - encoder_out, encoder_out_lens, predict_loss = self.forward_encoder(x, x_lens, - aux_loss_scale=aux_loss_scale, - sd_prob=sd_prob) - - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - if self.use_transducer: - # Compute transducer loss - simple_loss, pruned_loss = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(device), - y_lens=y_lens, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) - else: - simple_loss = torch.empty(0) - pruned_loss = torch.empty(0) - - if self.use_ctc: - targets = y.values - if not self.training: - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) - cr_loss = torch.empty(0) - else: - ctc_loss, cr_loss = self.forward_cr_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) - else: - ctc_loss = torch.empty(0) - cr_loss = torch.empty(0) - - if self.use_attention_decoder: - attention_decoder_loss = self.attention_decoder.calc_att_loss( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ys=y.to(device), - ys_lens=y_lens.to(device), - ) - else: - attention_decoder_loss = torch.empty(0) - - reconstruction_loss = self.forward_reconstruction_loss(x_no_specaug, encoder_out, - encoder_out_lens) - - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss - - - def forward_reconstruction_loss(self, - log_mels: Tensor, - encoder_out: Tensor, - encoder_out_lens: Tensor): - """ - Compute and return reconstruction loss, a mixed l1/l2 loss on the input features. If - use_cr_ctc then we swap the first and second halves of the batch. - - Args: - log_mels: log-mel features of shape (batch_size, T, num_mels) - encoder_out: embeddings of shape (batch_size, T_embed, encoder_dim) - """ - batch_size = log_mels.shape[0] - num_mels = log_mels.shape[2] - - - def gauss_norm(x): - # normalize by gaussianizing on each dimension - values, indexes = x.sort(dim=1) # sort on seq dim - N = max(2, x.shape[1]) - norm_rank = torch.linspace(-1 + 1. / N, 1. - 1. / N, x.shape[1], device=x.device, dtype=torch.float) - norm_rank = torch.special.erfinv(norm_rank) # maps to Gaussian-distributed data - norm_rank = norm_rank.reshape(1, -1, 1) - norm_rank = norm_rank.repeat(x.shape[0], 1, x.shape[2]) - x_norm = torch.empty_like(x) - x_norm.scatter_(dim=1, index=indexes, src=norm_rank) - return x_norm - - log_mels = gauss_norm(log_mels) - - pred_mels = self.reconstruction_proj(encoder_out) # (batch_size, T_embed, 4 * num_mels) - T_embed = pred_mels.shape[1] - pred_mels = pred_mels.reshape(batch_size, T_embed * 4, num_mels) - - excess_frames = log_mels.shape[1] - pred_mels.shape[1] - assert 4 < excess_frames < 10 # should be around 7 or 8 I believe. - - T = pred_mels.shape[1] - offset = 3 # i found excess_frames = 5 one time. - log_mels = log_mels[:, offset:offset+T] - - lens = encoder_out_lens * 4 - pad_mask = make_pad_mask(lens) # boolean Tensor with True for masked positions - assert pad_mask.shape == (batch_size, T) - pad_mask = (~pad_mask).to(torch.float).unsqueeze(-1) # 0.0 for masked position - # padd_mask: (batch_size, T, 1) - - - # use 1.0 for the beta; note, log-mels have a fairly large dynamic range so this mostly - # helps to down-weight the effect of very silent silences. - #loss = torch.nn.functional.smooth_l1_loss(log_mels * pad_mask, pred_mels * pad_mask, - # reduction='none', beta=1.0) - # this way of applying the padding mask is not really ideal in terms of normalization, - # it will cause us to under-normalize a bit. - diff = log_mels * pad_mask - pred_mels * pad_mask - - loss = (diff ** 2) - - # removing the masking logic since we now use the no-specaug reference sequence. - ## masking. if it's different from the next item on both the frequency dim - ## and the time dim, it means we are in neither a time masked nor a frequency masked - ## position. - #mask = torch.logical_and(log_mels != torch.roll(log_mels, 1, dims=2), - # log_mels != torch.roll(log_mels, 1, dims=1)) - #loss = loss * mask.to(loss.dtype) - - loss = loss.mean(dim=-1).sum() # sum over all frames, but mean over mel bins. - return loss diff --git a/egs/librispeech/ASR/zapformer2/my_profile.py b/egs/librispeech/ASR/zapformer2/my_profile.py deleted file mode 120000 index 76e48b756b..0000000000 --- a/egs/librispeech/ASR/zapformer2/my_profile.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/my_profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_check.py b/egs/librispeech/ASR/zapformer2/onnx_check.py deleted file mode 120000 index 7293c70d46..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_decode.py b/egs/librispeech/ASR/zapformer2/onnx_decode.py deleted file mode 120000 index 9e3faa5e01..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_decode.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py deleted file mode 120000 index f8abb9daa5..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py deleted file mode 120000 index 11b846322e..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained.py deleted file mode 120000 index a085def837..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py deleted file mode 120000 index 0c082a204f..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py deleted file mode 120000 index 68102c7374..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_H.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py deleted file mode 120000 index 8314b4efdf..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HL.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py deleted file mode 120000 index 7a637a1c01..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py deleted file mode 120000 index a5b04b3f8b..0000000000 --- a/egs/librispeech/ASR/zapformer2/onnx_pretrained_ctc_HLG_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/optim.py b/egs/librispeech/ASR/zapformer2/optim.py deleted file mode 120000 index 207eecfcda..0000000000 --- a/egs/librispeech/ASR/zapformer2/optim.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained.py b/egs/librispeech/ASR/zapformer2/pretrained.py deleted file mode 120000 index 70ad71ffc6..0000000000 --- a/egs/librispeech/ASR/zapformer2/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/pretrained_ctc.py b/egs/librispeech/ASR/zapformer2/pretrained_ctc.py deleted file mode 120000 index fb9bdf1fa2..0000000000 --- a/egs/librispeech/ASR/zapformer2/pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py deleted file mode 100755 index aa85d1fff7..0000000000 --- a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_k_2.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -import triton.language as tl -import triton -import torch - - -def get_autotune_config(): - configs = [] - configs.append( - triton.Config( - { - "BLOCK_M": 1, - "BLOCK_N": 32, - "BLOCK_C": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=2, - ) - ) - return configs - - -@triton.autotune( - configs=get_autotune_config()[-2:], - key=["seq_q", "seq_k", "channels", "max_seq_len"], -) -@triton.jit -def relative_position_attention_bwd_k_kernel( - # fmt: off - q_ptr, # (batches, head, seq_q, channel) - k_ptr, # (batches, head, seq_k, channel) - pos_ptr, # (head, 2*max_seq_len-1, channel) - scores_grad_ptr, # (batches, head, seq_q, seq_k) - B, H, seq_q, seq_k, channels, max_seq_len, # shape - stride_qb, stride_qh, stride_qs, stride_qc, # stride for q - stride_kb, stride_kh, stride_ks, stride_kc, # stride for k - stride_ph, stride_ps, stride_pc, # stride for pos - stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores - BLOCK_M: tl.constexpr, # block size in scores_grad - BLOCK_N: tl.constexpr, # block size in q - BLOCK_C: tl.constexpr, # block size for seq_q - GROUP_SIZE_M: tl.constexpr, # size for grouped block -): - # fmt: on - pid = tl.program_id(axis=0) - pid_bh = tl.program_id(axis=1) - - head = pid_bh % H - batch = pid_bh // H - - num_pid_m = tl.cdiv(seq_k, BLOCK_M) - num_pid_n = tl.cdiv(channels, BLOCK_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_in_group = pid % num_pid_in_group - pid_m = first_pid_m + (pid_in_group % group_size_m) - pid_n = pid_in_group // group_size_m - - tl.assume(BLOCK_M == 1) - - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - tl.assume(stride_qb > 0) - tl.assume(stride_qh > 0) - tl.assume(stride_qs > 0) - tl.assume(stride_qc > 0) - - tl.assume(stride_kb > 0) - tl.assume(stride_kh > 0) - tl.assume(stride_ks > 0) - tl.assume(stride_kc > 0) - - tl.assume(stride_ph > 0) - tl.assume(stride_ps > 0) - tl.assume(stride_pc > 0) - - tl.assume(stride_sb > 0) - tl.assume(stride_sh > 0) - tl.assume(stride_sq > 0) - tl.assume(stride_sk > 0) - - # (BLOCK_M,), for k, seq_k - offs_m = pid_m * BLOCK_M - - # (BLOCK_N,), for j, channel - offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_n_mask = offs_n[:, None] < channels - - # (BLOCK_C,), for i, seq_q - offs_c = tl.arange(0, BLOCK_C) - - q_base = q_ptr + batch * stride_qb + head * stride_qh + offs_n[:, None] * stride_qc - k_base = k_ptr + batch * stride_kb + head * stride_kh - pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc - scores_grad_base = ( - scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sk - ) - - acc = tl.zeros((BLOCK_N,), dtype=tl.float32) - - for c in range(0, channels, BLOCK_C): - c_idx = c + offs_c - - # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) - scores_grad_mask = (offs_m < seq_k) & (c_idx[None, :] < seq_q) - - # (BLOCK_N, BLOCK_C), or (J, I) - q_mask = offs_n_mask & (c_idx[None, :] < seq_q) - - # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) - rel_idx = c_idx[None, :] - offs_m + max_seq_len - 1 - - # (BLOCK_M, BLOCK_N, BLOCK_C), or (K, J, I), or (BLOCK_N, BLOCK_C) - pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask - - scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sq - q_ptrs = q_base + c_idx[None, :] * stride_qs - - # (BLOCK_M, BLOCK_C), or (K, I), or (1, BLOCK_C) - scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) - - # (BLOCK_N, BLOCK_C), or (J, I) - q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) - - # (BLOCK_N, BLOCK_C), or (J, I) - pos_ptrs = pos_base + rel_idx * stride_ps - - pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) - - # scores_grad_chunk (1, BLOCK_C), or (K, I) - # q_chunk (BLOCK_N, BLOCK_C), or (J, I) - # pos_chunk (BLOCK_N, BLOCK_C), or (J, I) - qp = q_chunk * pos_chunk - - acc += tl.sum(scores_grad_chunk * qp, axis=1) - - k_ptrs = k_base + offs_m * stride_ks + offs_n * stride_kc - k_mask = (offs_m < seq_k) & (offs_n < channels) - tl.store(k_ptrs, acc, mask=k_mask) - - -def relative_position_attention_bwd_k(scores_grad, q, pos): - if not scores_grad.is_contiguous(): - scores_grad = scores_grad.contiguous() - - assert scores_grad.is_contiguous(), ( - scores_grad.shape, - scores_grad.stride(0), - scores_grad.stride(1), - scores_grad.stride(2), - scores_grad.stride(3), - ) - assert q.is_contiguous() - assert pos.is_contiguous() - - assert scores_grad.ndim == q.ndim == 4, (scores_grad.shape, q.shape) - - assert pos.ndim == 3, pos.shape - b, h, seq_q, seq_k = scores_grad.shape - - assert q.shape[0] == b, q.shape - assert q.shape[1] == h, q.shape - assert q.shape[2] == seq_q, q.shape - - c = q.shape[3] - - assert pos.shape[0] == h, pos.shape - pos.shape[2] == c, pos.shape - - max_seq_len = (pos.shape[1] + 1) // 2 - - assert scores_grad.device == q.device == pos.device, ( - scores_grad.device, - q.device, - pos.device, - ) - - k = torch.empty(b, h, seq_k, c, device=q.device) - - grid = lambda META: ( - triton.cdiv(seq_k, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), - b * h, - ) - - # fmt:off - relative_position_attention_bwd_k_kernel[grid]( - q, k, pos, scores_grad, - b, h, seq_q, seq_k, c, max_seq_len, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - pos.stride(0), pos.stride(1), pos.stride(2), - scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), - ) - # fmt: on - return k - - -def relative_position_attention_fwd_torch(q, k, pos): - # this function consume a lot of memory, may OOM - max_seq_len = (pos.shape[1] + 1) // 2 - seq_q = q.shape[2] - seq_k = k.shape[2] - - q = q.unsqueeze(3) - k = k.unsqueeze(2) - - i = torch.arange(seq_q, device=q.device).unsqueeze(1) - j = torch.arange(seq_k, device=q.device).unsqueeze(0) - rel = (i - j) + max_seq_len - 1 - rel = rel.clamp(0, pos.shape[1] - 1) - pos_indexed = pos[:, rel].unsqueeze(0) - - # q: (b, h, seq_q, 1, c) - # q: (b, h, 1, seq_k, c) - # pos: (1, h, seq_q, seq_k, c) - scores = (q * k * pos_indexed).sum(dim=-1) - return scores - - -configs = [] -configs.append( - triton.testing.Benchmark( - x_names=[ - "b", - "h", - "seq_q", - "seq_k", - "c", - ], # Argument names to use as an x-axis for the plot - x_vals=[ - (b, h, seq, seq, c) - for b in [1, 2, 3] - # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation - for h in [2, 4] - for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] - for c in [128, 256, 512] - ], - line_arg="provider", # Argument name whose value corresponds to a different line in the plot - line_vals=["triton"], - line_names=["Triton"], - styles=[("green", "-")], - ylabel="time (ms)", # Label name for the y-axis - plot_name="matmul-performance", - args=dict(), - ) -) - - -@triton.testing.perf_report(configs) -def benchmark(b, h, seq_q, seq_k, c, provider): - device = torch.device("cuda", 0) - max_seq_len = seq_q - - q = torch.randn(b, h, seq_q, c, device=device) - - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) - - quantiles = [0.5, 0.2, 0.8] - if provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: relative_position_attention_bwd_k(scores_grad, q, pos), - quantiles=quantiles, - ) - return ms, max_ms, min_ms - - -def test_benchmark(): - benchmark.run(show_plots=False, print_data=True) - - -def test_correctness(): - device = torch.device("cuda", 0) - b = 2 - h = 2 - seq_q = 250 - seq_k = 250 - c = 1025 - max_seq_len = seq_q - - q = torch.randn(b, h, seq_q, c, device=device) - k = torch.randn(b, h, seq_k, c, device=device) - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - - q_copy = q.clone() - pos_copy = pos.clone() - - q.requires_grad_(True) - k.requires_grad_(True) - pos.requires_grad_(True) - - scores0 = relative_position_attention_fwd_torch(q, k, pos) - scores0.retain_grad() - - scale = torch.rand_like(scores0) - s0 = (scale * scores0).sum() - s0.backward() - print("score0.grad", scores0.grad.shape, scores0.grad.sum()) - print("k.grad", k.grad.shape, k.grad.sum()) - - scores_grad = scores0.grad.clone() - k_grad = relative_position_attention_bwd_k(scores_grad, q_copy, pos_copy) - - print(k_grad.shape, k_grad.sum()) - print((k.grad - k_grad).abs().max()) - - -def main(): - test_benchmark() - # test_correctness() - - -if __name__ == "__main__": - torch.manual_seed(20250812) - main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py deleted file mode 100755 index 93d1f09dc3..0000000000 --- a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_pos_2.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -import triton.language as tl -import triton -import torch - - -def get_autotune_config(): - configs = [] - configs.append( - triton.Config( - { - "BLOCK_M": 1, - "BLOCK_N": 16, - "BLOCK_C": 16, - "GROUP_SIZE_M": 4, - }, - num_stages=2, - num_warps=2, - ) - ) - return configs - - -@triton.autotune( - configs=get_autotune_config()[-2:], - key=["seq_q", "seq_k", "channels", "max_seq_len"], -) -@triton.jit -def relative_position_attention_bwd_pos_kernel( - # fmt: off - q_ptr, # (batches, head, seq_q, channel) - k_ptr, # (batches, head, seq_k, channel) - pos_ptr, # (head, 2*max_seq_len-1, channel) - scores_grad_ptr, # (batches, head, seq_q, seq_k) - B, H, seq_q, seq_k, channels, max_seq_len, # shape - stride_qb, stride_qh, stride_qs, stride_qc, # stride for q - stride_kb, stride_kh, stride_ks, stride_kc, # stride for k - stride_ph, stride_ps, stride_pc, # stride for pos - stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores - BLOCK_M: tl.constexpr, # block size in q - BLOCK_N: tl.constexpr, # block size in k - BLOCK_C: tl.constexpr, # block size for channel - GROUP_SIZE_M: tl.constexpr, # size for grouped block, not used -): - # fmt: on - pid = tl.program_id(axis=0) - pid_bh = tl.program_id(axis=1) - - head = pid_bh % H - batch = pid_bh // H - - num_pid_n = tl.cdiv(seq_k, BLOCK_N) - - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - tl.assume(BLOCK_M == 1) - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - tl.assume(stride_qb > 0) - tl.assume(stride_qh > 0) - tl.assume(stride_qs > 0) - tl.assume(stride_qc > 0) - - tl.assume(stride_kb > 0) - tl.assume(stride_kh > 0) - tl.assume(stride_ks > 0) - tl.assume(stride_kc > 0) - - tl.assume(stride_ph > 0) - tl.assume(stride_ps > 0) - tl.assume(stride_pc > 0) - - tl.assume(stride_sb > 0) - tl.assume(stride_sh > 0) - tl.assume(stride_sq > 0) - tl.assume(stride_sk > 0) - - offs_m = pid_m * BLOCK_M - - # (BLOCK_N,) - offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # (BLOCK_C,) - offs_c = tl.arange(0, BLOCK_C) - - # (BLOCK_N, 1) - rel_idx = offs_m - offs_n[:, None] + max_seq_len - 1 - - q_base = q_ptr + batch * stride_qb + head * stride_qh - k_base = k_ptr + batch * stride_kb + head * stride_kh - pos_base = pos_ptr + head * stride_ph - - scores_grad_base = scores_grad_ptr + batch * stride_sb + head * stride_sh - scores_grad_ptrs = ( - scores_grad_base + offs_m * stride_sq + offs_n[:, None] * stride_sk - ) - - # (BLOCK_N, 1) - scores_grad_mask = (offs_m < seq_q) & (offs_n[:, None] < seq_k) - - # (BLOCK_N, 1) - scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) - - for c in range(0, channels, BLOCK_C): - c_idx = c + offs_c - - # (1, BLOCK_C) - q_mask = (offs_m < seq_q) & (c_idx[None, :] < channels) - - # (BLOCK_N, BLOCK_C), or (K, J) - k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) - - # (BLOCK_N, BLOCK_C) - pos_mask = ( - (rel_idx >= 0) - & (rel_idx < 2 * max_seq_len - 1) - & (c_idx[None, :] < channels) - ) - - q_ptrs = q_base + offs_m * stride_qs + c_idx[None, :] * stride_qc - k_ptrs = k_base + offs_n[:, None] * stride_ks + c_idx[None, :] * stride_kc - - # (1, BLOCK_C) - q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) - - # (BLOCK_N, BLOCK_C) - k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) - - # (BLOCK_N, BLOCK_C) - pos_ptrs = pos_base + rel_idx * stride_ps + c_idx[None, :] * stride_pc - - # q_chunk (1, BLOCK_C) - # k_chunk (BLOCK_N, BLOCK_C) - # scores_grad_chunk (BLOCK_N, 1) - # - # pos_chunk: (BLOCK_N, BLOCK_C) - qk = q_chunk * k_chunk - pos_chunk = scores_grad_chunk * qk - - tl.atomic_add(pos_ptrs, pos_chunk, mask=pos_mask) - - -def relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len): - if not scores_grad.is_contiguous(): - scores_grad = scores_grad.contiguous() - - assert scores_grad.is_contiguous(), ( - scores_grad.shape, - scores_grad.stride(0), - scores_grad.stride(1), - scores_grad.stride(2), - scores_grad.stride(3), - ) - - assert q.is_contiguous() - assert k.is_contiguous() - - assert scores_grad.ndim == q.ndim == k.ndim == 4, ( - scores_grad.shape, - q.shape, - k.shape, - ) - b, h, seq_q, seq_k = scores_grad.shape - c = q.shape[3] - - assert k.shape[0] == b, k.shape - assert k.shape[1] == h, k.shape - assert k.shape[2] == seq_k, k.shape - assert k.shape[3] == c, k.shape - - assert q.shape[0] == b, q.shape - assert q.shape[1] == h, q.shape - assert q.shape[2] == seq_q, q.shape - - assert scores_grad.device == q.device == k.device, ( - scores_grad.device, - q.device, - k.device, - ) - - pos = torch.zeros(h, 2 * max_seq_len - 1, c, device=q.device) - - grid = lambda META: ( - triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), - b * h, - ) - - # fmt:off - relative_position_attention_bwd_pos_kernel[grid]( - q, k, pos, scores_grad, - b, h, seq_q, seq_k, c, max_seq_len, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - pos.stride(0), pos.stride(1), pos.stride(2), - scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), - ) - # fmt: on - return pos - - -def relative_position_attention_fwd_torch(q, k, pos): - # this function consume a lot of memory, may OOM - max_seq_len = (pos.shape[1] + 1) // 2 - seq_q = q.shape[2] - seq_k = k.shape[2] - - q = q.unsqueeze(3) - k = k.unsqueeze(2) - - i = torch.arange(seq_q, device=q.device).unsqueeze(1) - j = torch.arange(seq_k, device=q.device).unsqueeze(0) - rel = (i - j) + max_seq_len - 1 - rel = rel.clamp(0, pos.shape[1] - 1) - pos_indexed = pos[:, rel].unsqueeze(0) - - # q: (b, h, seq_q, 1, c) - # q: (b, h, 1, seq_k, c) - # pos: (1, h, seq_q, seq_k, c) - scores = (q * k * pos_indexed).sum(dim=-1) - return scores - - -configs = [] -configs.append( - triton.testing.Benchmark( - x_names=[ - "b", - "h", - "seq_q", - "seq_k", - "c", - ], # Argument names to use as an x-axis for the plot - x_vals=[ - (b, h, seq, seq, c) - for b in [1, 2, 3] - # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation - for h in [2, 4] - for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] - for c in [128, 256, 512] - ], - line_arg="provider", # Argument name whose value corresponds to a different line in the plot - line_vals=["triton"], - line_names=["Triton"], - styles=[("green", "-")], - ylabel="time (ms)", # Label name for the y-axis - plot_name="matmul-performance", - args=dict(), - ) -) - - -@triton.testing.perf_report(configs) -def benchmark(b, h, seq_q, seq_k, c, provider): - device = torch.device("cuda", 0) - max_seq_len = seq_q - - scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) - q = torch.randn(b, h, seq_q, c, device=device) - k = torch.randn(b, h, seq_k, c, device=device) - - quantiles = [0.5, 0.2, 0.8] - if provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: relative_position_attention_bwd_pos(scores_grad, q, k, max_seq_len), - quantiles=quantiles, - ) - return ms, max_ms, min_ms - - -def test_benchmark(): - benchmark.run(show_plots=False, print_data=True) - - -def test_correctness(): - device = torch.device("cuda", 0) - b = 2 - h = 2 - seq_q = 250 - seq_k = 250 - c = 1025 - max_seq_len = seq_q - - q = torch.randn(b, h, seq_q, c, device=device) - k = torch.randn(b, h, seq_k, c, device=device) - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - - q_copy = q.clone() - k_copy = k.clone() - - q.requires_grad_(True) - k.requires_grad_(True) - pos.requires_grad_(True) - - scores0 = relative_position_attention_fwd_torch(q, k, pos) - scores0.retain_grad() - - scale = torch.rand_like(scores0) - - s0 = (scale * scores0).sum() - s0.backward() - print("score0.grad", scores0.grad.shape, scores0.grad.sum()) - print("pos.grad", pos.grad.shape, pos.grad.sum()) - - pos_grad = relative_position_attention_bwd_pos( - scores0.grad, q_copy, k_copy, max_seq_len - ) - - print(pos_grad.shape, pos_grad.sum()) - print((pos.grad - pos_grad).abs().max()) - - -def main(): - # test_benchmark() - test_correctness() - - -if __name__ == "__main__": - torch.manual_seed(20250812) - main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py deleted file mode 100755 index 5a9ececf0c..0000000000 --- a/egs/librispeech/ASR/zapformer2/relative_position_attention_bwd_q_2.py +++ /dev/null @@ -1,332 +0,0 @@ -#!/usr/bin/env python3 -import triton.language as tl -import triton -import torch - - -def get_autotune_config(): - configs = [] - configs.append( - triton.Config( - { - "BLOCK_M": 1, - "BLOCK_N": 32, - "BLOCK_C": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=2, - ) - ) - return configs - - -@triton.autotune( - configs=get_autotune_config()[-2:], - key=["seq_q", "seq_k", "channels", "max_seq_len"], -) -@triton.jit -def relative_position_attention_bwd_q_kernel( - # fmt: off - q_ptr, # (batches, head, seq_q, channel) - k_ptr, # (batches, head, seq_k, channel) - pos_ptr, # (head, 2*max_seq_len-1, channel) - scores_grad_ptr, # (batches, head, seq_q, seq_k) - B, H, seq_q, seq_k, channels, max_seq_len, # shape - stride_qb, stride_qh, stride_qs, stride_qc, # stride for q - stride_kb, stride_kh, stride_ks, stride_kc, # stride for k - stride_ph, stride_ps, stride_pc, # stride for pos - stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores - BLOCK_M: tl.constexpr, # block size in scores_grad - BLOCK_N: tl.constexpr, # block size in channels - BLOCK_C: tl.constexpr, # block size for seq_k - GROUP_SIZE_M: tl.constexpr, # size for grouped block -): - # fmt: on - pid = tl.program_id(axis=0) - pid_bh = tl.program_id(axis=1) - - head = pid_bh % H - batch = pid_bh // H - - num_pid_m = tl.cdiv(seq_q, BLOCK_M) - num_pid_n = tl.cdiv(channels, BLOCK_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_in_group = pid % num_pid_in_group - pid_m = first_pid_m + (pid_in_group % group_size_m) - pid_n = pid_in_group // group_size_m - - tl.assume(BLOCK_M == 1) - - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - tl.assume(stride_qb > 0) - tl.assume(stride_qh > 0) - tl.assume(stride_qs > 0) - tl.assume(stride_qc > 0) - - tl.assume(stride_kb > 0) - tl.assume(stride_kh > 0) - tl.assume(stride_ks > 0) - tl.assume(stride_kc > 0) - - tl.assume(stride_ph > 0) - tl.assume(stride_ps > 0) - tl.assume(stride_pc > 0) - - tl.assume(stride_sb > 0) - tl.assume(stride_sh > 0) - tl.assume(stride_sq > 0) - tl.assume(stride_sk > 0) - - # (BLOCK_M,), we should always set BLOCK_M to 1 - offs_m = pid_m * BLOCK_M - - # (BLOCK_N,) for channels - offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # (BLOCK_C,), for seq_k - offs_c = tl.arange(0, BLOCK_C) - - # (BLOCK_N, 1) - offs_n_mask = offs_n[:, None] < channels - - q_base = q_ptr + batch * stride_qb + head * stride_qh - k_base = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_kc - pos_base = pos_ptr + head * stride_ph + offs_n[:, None] * stride_pc - scores_grad_base = ( - scores_grad_ptr + batch * stride_sb + head * stride_sh + offs_m * stride_sq - ) - - acc = tl.zeros((BLOCK_N,), dtype=tl.float32) - - for c in range(0, seq_k, BLOCK_C): - c_idx = c + offs_c - - # (1, BLOCK_C) - rel_idx = offs_m - c_idx[None, :] + max_seq_len - 1 - - # (1, BLOCK_C) - scores_grad_mask = (offs_m < seq_q) & (c_idx[None, :] < seq_k) - - # (BLOCK_N, BLOCK_C) - k_mask = offs_n_mask & (c_idx[None, :] < seq_k) - - # (BLOCK_N, BLOCK_C) - pos_mask = (rel_idx >= 0) & (rel_idx < 2 * max_seq_len - 1) & offs_n_mask - - scores_grad_ptrs = scores_grad_base + c_idx[None, :] * stride_sk - k_ptrs = k_base + c_idx[None, :] * stride_ks - - # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) - scores_grad_chunk = tl.load(scores_grad_ptrs, mask=scores_grad_mask, other=0.0) - - # (BLOCK_N, BLOCK_C) - k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) - - # (BLOCK_N, BLOCK_C) - pos_ptrs = pos_base + rel_idx * stride_ps - - pos_chunk = tl.load(pos_ptrs, mask=pos_mask, other=0.0) - - # scores_grad_chunk (1, BLOCK_C) - # k_chunk (BLOCK_N, BLOCK_C) - # pos_chunk (BLOCK_N, BLOCK_C) - - # kp: (BLOCK_N, BLOCK_C) - kp = k_chunk * pos_chunk - - acc += tl.sum(scores_grad_chunk * kp, axis=1) - - q_ptrs = q_base + offs_m * stride_qs + offs_n * stride_qc - q_mask = (offs_m < seq_q) & (offs_n < channels) - tl.store(q_ptrs, acc, mask=q_mask) - - -def relative_position_attention_bwd_q(scores_grad, k, pos): - """ - Args: - scores_grad: (b, h, seq_q, seq_k) - k: (b, h, seq_k, channels) - pos: (h, 2*max_seq_len-1, channels) - Returns: - grad of q: (b, h, seq_q, channels) - """ - if not scores_grad.is_contiguous(): - scores_grad = scores_grad.contiguous() - - assert scores_grad.is_contiguous(), ( - scores_grad.shape, - scores_grad.stride(0), - scores_grad.stride(1), - scores_grad.stride(2), - scores_grad.stride(3), - ) - assert k.is_contiguous() - assert pos.is_contiguous() - - assert scores_grad.ndim == k.ndim == 4, (scores_grad.shape, k.shape) - assert pos.ndim == 3, pos.shape - b, h, seq_q, seq_k = scores_grad.shape - - c = k.shape[3] - - assert k.shape[0] == b, (k.shape, scores_grad.shape) - assert k.shape[1] == h, (k.shape, scores_grad.shape) - assert k.shape[2] == seq_k, (k.shape, scores_grad.shape) - - assert pos.shape[0] == h, pos.shape - pos.shape[2] == c, pos.shape - - max_seq_len = (pos.shape[1] + 1) // 2 - - assert scores_grad.device == k.device == pos.device, ( - scores_grad.device, - k.device, - pos.device, - ) - - q = torch.empty(b, h, seq_q, c, device=k.device) - - grid = lambda META: ( - triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(c, META["BLOCK_N"]), - b * h, - ) - - # fmt:off - relative_position_attention_bwd_q_kernel[grid]( - q, k, pos, scores_grad, - b, h, seq_q, seq_k, c, max_seq_len, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - pos.stride(0), pos.stride(1), pos.stride(2), - scores_grad.stride(0), scores_grad.stride(1), scores_grad.stride(2), scores_grad.stride(3), - ) - # fmt: on - return q - - -def relative_position_attention_fwd_torch(q, k, pos): - # this function consume a lot of memory, may OOM - max_seq_len = (pos.shape[1] + 1) // 2 - seq_q = q.shape[2] - seq_k = k.shape[2] - - q = q.unsqueeze(3) - k = k.unsqueeze(2) - - i = torch.arange(seq_q, device=q.device).unsqueeze(1) - j = torch.arange(seq_k, device=q.device).unsqueeze(0) - rel = (i - j) + max_seq_len - 1 - rel = rel.clamp(0, pos.shape[1] - 1) - pos_indexed = pos[:, rel].unsqueeze(0) - - # q: (b, h, seq_q, 1, c) - # q: (b, h, 1, seq_k, c) - # pos: (1, h, seq_q, seq_k, c) - scores = (q * k * pos_indexed).sum(dim=-1) - return scores - - -configs = [] -configs.append( - triton.testing.Benchmark( - x_names=[ - "b", - "h", - "seq_q", - "seq_k", - "c", - ], # Argument names to use as an x-axis for the plot - x_vals=[ - (b, h, seq, seq, c) - for b in [1, 2, 3] - # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation - for h in [2, 4] - for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] - for c in [128, 256, 512] - ], - line_arg="provider", # Argument name whose value corresponds to a different line in the plot - line_vals=["triton"], - line_names=["Triton"], - styles=[("green", "-")], - ylabel="time (ms)", # Label name for the y-axis - plot_name="matmul-performance", - args=dict(), - ) -) - - -@triton.testing.perf_report(configs) -def benchmark(b, h, seq_q, seq_k, c, provider): - device = torch.device("cuda", 0) - max_seq_len = seq_q - - k = torch.randn(b, h, seq_k, c, device=device) - - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - - scores_grad = torch.randn(b, h, seq_q, seq_k, device=device) - - quantiles = [0.5, 0.2, 0.8] - if provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: relative_position_attention_bwd_q(scores_grad, k, pos), - quantiles=quantiles, - ) - return ms, max_ms, min_ms - - -def test_benchmark(): - benchmark.run(show_plots=False, print_data=True) - - -def test_correctness(): - device = torch.device("cuda", 0) - b = 2 - h = 2 - seq_q = 250 - seq_k = 250 - c = 1025 - max_seq_len = seq_q - - q = torch.randn(b, h, seq_q, c, device=device) - k = torch.randn(b, h, seq_k, c, device=device) - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - - k_copy = k.clone() - pos_copy = pos.clone() - q.requires_grad_(True) - k.requires_grad_(True) - pos.requires_grad_(True) - - scores0 = relative_position_attention_fwd_torch(q, k, pos) - scores0.retain_grad() - - scale = torch.rand_like(scores0) - - s0 = (scale * scores0).sum() - s0.backward() - print("score0.grad", scores0.grad.shape, scores0.grad.sum()) - print("q.grad", q.grad.shape, q.grad.sum()) - - scores_grad = scores0.grad.clone() - q_grad = relative_position_attention_bwd_q(scores_grad, k_copy, pos_copy) - print(q_grad.shape, q_grad.sum()) - print((q.grad - q_grad).abs().max()) - - -def main(): - test_benchmark() - # test_correctness() - - -if __name__ == "__main__": - torch.manual_seed(20250812) - main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py deleted file mode 100755 index e6ea552035..0000000000 --- a/egs/librispeech/ASR/zapformer2/relative_position_attention_fwd_2.py +++ /dev/null @@ -1,302 +0,0 @@ -#!/usr/bin/env python3 -import triton.language as tl -import triton -import torch - - -def get_autotune_config(): - configs = [] - configs.append( - triton.Config( - { - "BLOCK_M": 1, - "BLOCK_N": 32, - "BLOCK_C": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=2, - ) - ) - return configs - - -@triton.autotune( - configs=get_autotune_config()[-2:], - key=["seq_q", "seq_k", "channels", "max_seq_len"], -) -@triton.jit -def relative_position_attention_fwd_kernel( - # fmt: off - q_ptr, # (batches, head, seq_q, channel) - k_ptr, # (batches, head, seq_k, channel) - pos_ptr, # (head, 2*max_seq_len-1, channel) - scores_ptr, # (batches, head, seq_q, seq_k) - B, H, seq_q, seq_k, channels, max_seq_len, # shape - stride_qb, stride_qh, stride_qs, stride_qc, # stride for q - stride_kb, stride_kh, stride_ks, stride_kc, # stride for k - stride_ph, stride_ps, stride_pc, # stride for pos - stride_sb, stride_sh, stride_sq, stride_sk, # stride for scores - BLOCK_M: tl.constexpr, # block size in q - BLOCK_N: tl.constexpr, # block size in k - BLOCK_C: tl.constexpr, # block size for channel - GROUP_SIZE_M: tl.constexpr, # size for grouped block -): - # fmt: on - pid = tl.program_id(axis=0) - pid_bh = tl.program_id(axis=1) - - head = pid_bh % H - batch = pid_bh // H - - num_pid_m = tl.cdiv(seq_q, BLOCK_M) - num_pid_n = tl.cdiv(seq_k, BLOCK_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_in_group = pid % num_pid_in_group - pid_m = first_pid_m + (pid_in_group % group_size_m) - pid_n = pid_in_group // group_size_m - - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - tl.assume(stride_qb > 0) - tl.assume(stride_qh > 0) - tl.assume(stride_qs > 0) - tl.assume(stride_qc > 0) - - tl.assume(stride_kb > 0) - tl.assume(stride_kh > 0) - tl.assume(stride_ks > 0) - tl.assume(stride_kc > 0) - - tl.assume(stride_ph > 0) - tl.assume(stride_ps > 0) - tl.assume(stride_pc > 0) - - tl.assume(stride_sb > 0) - tl.assume(stride_sh > 0) - tl.assume(stride_sq > 0) - tl.assume(stride_sk > 0) - - # (BLOCK_M,), we should always set BLOCK_M to 1 - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # (BLOCK_N,) - offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # (BLOCK_C,) - offs_c = tl.arange(0, BLOCK_C) - - # (BLOCK_N, ) - rel_idx = offs_m - offs_n + max_seq_len - 1 - - # (BLOCK_N, 1) - rel_idx_mask = (rel_idx[:, None] >= 0) & (rel_idx[:, None] < 2 * max_seq_len - 1) - - q_ptrs = q_ptr + batch * stride_qb + head * stride_qh + offs_m[:, None] * stride_qs - k_ptrs = k_ptr + batch * stride_kb + head * stride_kh + offs_n[:, None] * stride_ks - - pos_ptrs = pos_ptr + head * stride_ph + rel_idx[:, None] * stride_ps - - acc = tl.zeros((BLOCK_N,), dtype=tl.float32) - - for c in range(0, channels, BLOCK_C): - c_idx = c + offs_c - - # (BLOCK_M, BLOCK_C) - q_mask = (offs_m[:, None] < seq_q) & (c_idx[None, :] < channels) - - # (BLOCK_N, BLOCK_C) - k_mask = (offs_n[:, None] < seq_k) & (c_idx[None, :] < channels) - - # (BLOCK_N, BLOCK_C) - pos_mask = rel_idx_mask & (c_idx[None, :] < channels) - - q_ptrs0 = q_ptrs + c_idx[None, :] * stride_qc - k_ptrs0 = k_ptrs + c_idx[None, :] * stride_kc - - # (BLOCK_M, BLOCK_C), or (1, BLOCK_C) - q_chunk = tl.load(q_ptrs0, mask=q_mask, other=0.0) - - # (BLOCK_N, BLOCK_C) - k_chunk = tl.load(k_ptrs0, mask=k_mask, other=0.0) - - # (BLOCK_N, BLOCK_C) - pos_ptrs0 = pos_ptrs + c_idx[None, :] * stride_pc - - pos_chunk = tl.load(pos_ptrs0, mask=pos_mask, other=0.0) - - # q_chunk (1, BLOCK_C) - # k_chunk (BLOCK_N, BLOCK_C) - # pos_chunk (BLOCK_N, BLOCK_C) - - acc += tl.sum(q_chunk * (k_chunk * pos_chunk), axis=1) - - scores_ptrs = ( - scores_ptr - + batch * stride_sb - + head * stride_sh - + offs_m * stride_sq - + offs_n * stride_sk - ) - scores_mask = (offs_m < seq_q) & (offs_n < seq_k) - - tl.store(scores_ptrs, acc, mask=scores_mask) - - -def relative_position_attention_fwd(q, k, pos): - assert q.is_contiguous() - assert k.is_contiguous() - assert pos.is_contiguous() - - assert q.ndim == k.ndim == 4, (q.shape, k.shape) - assert pos.ndim == 3, pos.shape - b, h, seq_q, c = q.shape - assert k.shape[0] == b, k.shape - assert k.shape[1] == h, k.shape - assert k.shape[3] == c, k.shape - - seq_k = k.shape[2] - - assert pos.shape[0] == h, pos.shape - pos.shape[2] == c, pos.shape - - max_seq_len = (pos.shape[1] + 1) // 2 - - assert q.device == k.device == pos.device, ( - q.device, - k.device, - pos.device, - ) - - scores = torch.empty(b, h, seq_q, seq_k, device=q.device) - - grid = lambda META: ( - triton.cdiv(seq_q, META["BLOCK_M"]) * triton.cdiv(seq_k, META["BLOCK_N"]), - b * h, - ) - - # fmt:off - relative_position_attention_fwd_kernel[grid]( - q, k, pos, scores, - b, h, seq_q, seq_k, c, max_seq_len, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - pos.stride(0), pos.stride(1), pos.stride(2), - scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3), - ) - # fmt: on - return scores - - -def relative_position_attention_fwd_torch(q, k, pos): - # this function consume a lot of memory, may OOM - max_seq_len = (pos.shape[1] + 1) // 2 - seq_q = q.shape[2] - seq_k = k.shape[2] - - q = q.unsqueeze(3) - k = k.unsqueeze(2) - - i = torch.arange(seq_q, device=q.device).unsqueeze(1) - j = torch.arange(seq_k, device=q.device).unsqueeze(0) - rel = (i - j) + max_seq_len - 1 - rel = rel.clamp(0, pos.shape[1] - 1) - pos_indexed = pos[:, rel].unsqueeze(0) - - # q: (b, h, seq_q, 1, c) - # q: (b, h, 1, seq_k, c) - # pos: (1, h, seq_q, seq_k, c) - scores = (q * k * pos_indexed).sum(dim=-1) - return scores - - -configs = [] -configs.append( - triton.testing.Benchmark( - x_names=[ - "b", - "h", - "seq_q", - "seq_k", - "c", - ], # Argument names to use as an x-axis for the plot - x_vals=[ - (b, h, seq, seq, c) - for b in [1, 2, 3] - # for b in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # cause OOM for torch's implementation - for h in [2, 4] - for seq in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] - for c in [128, 256, 512] - ], - line_arg="provider", # Argument name whose value corresponds to a different line in the plot - line_vals=["triton", "torch"], - line_names=["Triton", "Torch"], - styles=[("green", "-"), ("blue", "-")], - ylabel="time (ms)", # Label name for the y-axis - plot_name="matmul-performance with pos", - args=dict(), - ) -) - - -@triton.testing.perf_report(configs) -def benchmark(b, h, seq_q, seq_k, c, provider): - device = torch.device("cuda", 0) - - max_seq_len = seq_q - - q = torch.randn(b, h, seq_q, c, device=device) - k = torch.randn(b, h, seq_k, c, device=device) - - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - - quantiles = [0.5, 0.2, 0.8] - if provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: relative_position_attention_fwd_torch(q, k, pos), - quantiles=quantiles, - ) - if provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: relative_position_attention_fwd(q, k, pos), quantiles=quantiles - ) - return ms, max_ms, min_ms - - -def test_benchmark(): - benchmark.run(show_plots=False, print_data=True) - - -def test_correctness(): - device = torch.device("cuda", 0) - b = 2 - h = 8 - seq_q = 400 - seq_k = 400 - c = 1024 - max_seq_len = seq_q - - q = torch.randn(b, h, seq_q, c, device=device) - k = torch.randn(b, h, seq_k, c, device=device) - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - scores0 = relative_position_attention_fwd_torch(q, k, pos) - scores1 = relative_position_attention_fwd(q, k, pos) - print(scores0.shape, scores0.sum()) - print(scores1.shape, scores1.sum()) - print((scores0 - scores1).abs().max()) - - -def main(): - test_benchmark() - # test_correctness() - - -if __name__ == "__main__": - torch.manual_seed(20250812) - main() diff --git a/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py b/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py deleted file mode 100755 index 21640764ba..0000000000 --- a/egs/librispeech/ASR/zapformer2/relative_position_attention_module_optimized.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 - -import torch - -from relative_position_attention_fwd_2 import ( - relative_position_attention_fwd, - relative_position_attention_fwd_torch, -) - -from relative_position_attention_bwd_q_2 import relative_position_attention_bwd_q -from relative_position_attention_bwd_k_2 import relative_position_attention_bwd_k -from relative_position_attention_bwd_pos_2 import relative_position_attention_bwd_pos - - -class RelativePositionAttentionFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, pos): - """ - Args: - q: (batch, head, seq_q, channel) - k: (batch, head, seq_k, channel) - pos: (head, 2*max_seq_len-1, channel) - Returns: - scores: (batch, head, seq_q, seq_k) - """ - ctx.save_for_backward(q, k, pos) - return relative_position_attention_fwd(q, k, pos) - - @staticmethod - def backward(ctx, scores_grad): - q, k, pos = ctx.saved_tensors - q_grad = None - k_grad = None - pos_grad = None - - if ctx.needs_input_grad[0]: - q_grad = relative_position_attention_bwd_q(scores_grad, k, pos) - - if ctx.needs_input_grad[1]: - k_grad = relative_position_attention_bwd_k(scores_grad, q, pos) - - if ctx.needs_input_grad[2]: - max_seq_len = (pos.shape[1] + 1) // 2 - pos_grad = relative_position_attention_bwd_pos( - scores_grad, q, k, max_seq_len - ) - - return q_grad, k_grad, pos_grad - - -class RelativePositionAttentionModule(torch.nn.Module): - def forward( - self, q: torch.Tensor, k: torch.Tensor, pos: torch.Tensor - ) -> torch.Tensor: - """ - Args: - q: (batch, head, seq_q, channel) - k: (batch, head, seq_k, channel) - pos: (head, 2*max_seq_len-1, channel) - Returns: - scores: (batch, head, seq_q, seq_k) - """ - return RelativePositionAttentionFunction.apply(q, k, pos) - - -def _test(): - torch.manual_seed(20250820) - device = torch.device("cuda", 0) - b = 4 - h = 2 - seq_q = 100 - seq_k = 100 - c = 300 - max_seq_len = seq_q - - q = torch.randn(b, h, seq_q, c, device=device) - k = torch.randn(b, h, seq_k, c, device=device) - pos = torch.randn(h, 2 * max_seq_len - 1, c, device=device) - - q_copy = q.clone() - k_copy = k.clone() - pos_copy = pos.clone() - - q.requires_grad_(True) - k.requires_grad_(True) - pos.requires_grad_(True) - - scores0 = relative_position_attention_fwd_torch(q, k, pos) - - scale = torch.rand_like(scores0) - - s0 = (scale * scores0).sum() - s0.backward() - - q_copy.requires_grad_(True) - k_copy.requires_grad_(True) - pos_copy.requires_grad_(True) - - scores1 = RelativePositionAttentionModule()(q_copy, k_copy, pos_copy) - - s1 = (scale * scores1).sum() - s1.backward() - - print((s0 - s1).max().abs()) - print((q.grad - q_copy.grad).max().abs()) - print((k.grad - k_copy.grad).max().abs()) - print((pos.grad - pos_copy.grad).max().abs()) - """ - tensor(0.0005, device='cuda:0', grad_fn=) - tensor(7.6294e-06, device='cuda:0') - tensor(5.7220e-06, device='cuda:0') - tensor(3.4332e-05, device='cuda:0') - """ - - -if __name__ == "__main__": - _test() - pass diff --git a/egs/librispeech/ASR/zapformer2/scaling.py b/egs/librispeech/ASR/zapformer2/scaling.py deleted file mode 120000 index 58e4b0a0fe..0000000000 --- a/egs/librispeech/ASR/zapformer2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/scaling_converter.py b/egs/librispeech/ASR/zapformer2/scaling_converter.py deleted file mode 120000 index bc7c7b5e37..0000000000 --- a/egs/librispeech/ASR/zapformer2/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/speech_recognition.py b/egs/librispeech/ASR/zapformer2/speech_recognition.py deleted file mode 100755 index dd069cf3da..0000000000 --- a/egs/librispeech/ASR/zapformer2/speech_recognition.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import Callable, Dict, List, Union - -import torch -from torch.utils.data.dataloader import DataLoader, default_collate - -from lhotse import validate -from lhotse.cut import CutSet -from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures -from lhotse.utils import compute_num_frames, ifnone -from lhotse.workarounds import Hdf5MemoryIssueFix - - -class K2SpeechRecognitionDataset(torch.utils.data.Dataset): - """ - The PyTorch Dataset for the speech recognition task using k2 library. - - This dataset expects to be queried with lists of cut IDs, - for which it loads features and automatically collates/batches them. - - To use it with a PyTorch DataLoader, set ``batch_size=None`` - and provide a :class:`SimpleCutSampler` sampler. - - Each item in this dataset is a dict of: - - .. code-block:: - - { - 'inputs': float tensor with shape determined by :attr:`input_strategy`: - - single-channel: - - features: (B, T, F) - - audio: (B, T) - - multi-channel: currently not supported - 'supervisions': [ - { - 'sequence_idx': Tensor[int] of shape (S,) - 'text': List[str] of len S - - # For feature input strategies - 'start_frame': Tensor[int] of shape (S,) - 'num_frames': Tensor[int] of shape (S,) - - # For audio input strategies - 'start_sample': Tensor[int] of shape (S,) - 'num_samples': Tensor[int] of shape (S,) - - # Optionally, when return_cuts=True - 'cut': List[AnyCut] of len S - } - ] - } - - Dimension symbols legend: - * ``B`` - batch size (number of Cuts) - * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) - * ``T`` - number of frames of the longest Cut - * ``F`` - number of features - - The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. - """ - - def __init__( - self, - return_cuts: bool = False, - cut_transforms: List[Callable[[CutSet], CutSet]] = None, - input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, - input_strategy: BatchIO = PrecomputedFeatures(), - ): - """ - k2 ASR IterableDataset constructor. - - :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut - objects used to create that batch. - :param cut_transforms: A list of transforms to be applied on each sampled batch, - before converting cuts to an input representation (audio/features). - Examples: cut concatenation, noise cuts mixing, etc. - :param input_transforms: A list of transforms to be applied on each sampled batch, - after the cuts are converted to audio/features. - Examples: normalization, SpecAugment, etc. - :param input_strategy: Converts cuts into a collated batch of audio/features. - By default, reads pre-computed features from disk. - """ - super().__init__() - # Initialize the fields - self.return_cuts = return_cuts - self.cut_transforms = ifnone(cut_transforms, []) - self.input_transforms = ifnone(input_transforms, []) - self.input_strategy = input_strategy - - # This attribute is a workaround to constantly growing HDF5 memory - # throughout the epoch. It regularly closes open file handles to - # reset the internal HDF5 caches. - self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) - - def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: - """ - Return a new batch, with the batch size automatically determined using the constraints - of max_duration and max_cuts. - """ - validate_for_asr(cuts) - - self.hdf5_fix.update() - - # Sort the cuts by duration so that the first one determines the batch time dimensions. - cuts = cuts.sort_by_duration(ascending=False) - - if self.cut_transforms: - orig_cuts = cuts - - cuts = cuts.repeat(times=2) - - for tnfm in self.cut_transforms: - cuts = tnfm(cuts) - - cuts = orig_cuts + cuts - num_copies = 3 - else: - num_copies = 1 - - - # Get a tensor with batched feature matrices, shape (B, T, F) - # Collation performs auto-padding, if necessary. - input_tpl = self.input_strategy(cuts) - if len(input_tpl) == 3: - # An input strategy with fault tolerant audio reading mode. - # "cuts" may be a subset of the original "cuts" variable, - # that only has cuts for which we successfully read the audio. - inputs, _, cuts = input_tpl - else: - inputs, _ = input_tpl - - # Get a dict of tensors that encode the positional information about supervisions - # in the batch of feature matrices. The tensors are named "sequence_idx", - # "start_frame/sample" and "num_frames/samples". - supervision_intervals = self.input_strategy.supervision_intervals(cuts) - - # Apply all available transforms on the inputs, i.e. either audio or features. - # This could be feature extraction, global MVN, SpecAugment, etc. - segments = torch.stack(list(supervision_intervals.values()), dim=1) - for tnfm in self.input_transforms: - inputs = tnfm(inputs, supervision_segments=segments) - - batch = { - "inputs": inputs, - "num_copies": num_copies, - "supervisions": default_collate( - [ - { - "text": supervision.text, - } - for sequence_idx, cut in enumerate(cuts) - for supervision in cut.supervisions - ] - ), - } - # Update the 'supervisions' field with sequence_idx and start/num frames/samples - batch["supervisions"].update(supervision_intervals) - if self.return_cuts: - batch["supervisions"]["cut"] = [ - cut for cut in cuts for sup in cut.supervisions - ] - - has_word_alignments = all( - s.alignment is not None and "word" in s.alignment - for c in cuts - for s in c.supervisions - ) - if has_word_alignments: - # TODO: might need to refactor BatchIO API to move the following conditional logic - # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), - # that returns either num_frames or num_samples depending on the strategy). - words, starts, ends = [], [], [] - frame_shift = cuts[0].frame_shift - sampling_rate = cuts[0].sampling_rate - if frame_shift is None: - try: - frame_shift = self.input_strategy.extractor.frame_shift - except AttributeError: - raise ValueError( - "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " - ) - for c in cuts: - for s in c.supervisions: - words.append([aliword.symbol for aliword in s.alignment["word"]]) - starts.append( - [ - compute_num_frames( - aliword.start, - frame_shift=frame_shift, - sampling_rate=sampling_rate, - ) - for aliword in s.alignment["word"] - ] - ) - ends.append( - [ - compute_num_frames( - aliword.end, - frame_shift=frame_shift, - sampling_rate=sampling_rate, - ) - for aliword in s.alignment["word"] - ] - ) - batch["supervisions"]["word"] = words - batch["supervisions"]["word_start"] = starts - batch["supervisions"]["word_end"] = ends - - return batch - - -def validate_for_asr(cuts: CutSet) -> None: - validate(cuts) - tol = 2e-3 # 1ms - for cut in cuts: - for supervision in cut.supervisions: - assert supervision.start >= -tol, ( - f"Supervisions starting before the cut are not supported for ASR" - f" (sup id: {supervision.id}, cut id: {cut.id})" - ) - - # Supervision start time is relative to Cut ... - # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html - # - # 'supervision.end' is end of supervision inside the Cut - assert supervision.end <= cut.duration + tol, ( - f"Supervisions ending after the cut " - f"are not supported for ASR" - f" (sup id: {supervision.id}, cut id: {cut.id})" - ) diff --git a/egs/librispeech/ASR/zapformer2/streaming_beam_search.py b/egs/librispeech/ASR/zapformer2/streaming_beam_search.py deleted file mode 120000 index 97e6e733f2..0000000000 --- a/egs/librispeech/ASR/zapformer2/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/streaming_decode.py b/egs/librispeech/ASR/zapformer2/streaming_decode.py deleted file mode 120000 index e31da07d01..0000000000 --- a/egs/librispeech/ASR/zapformer2/streaming_decode.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/subsampling.py b/egs/librispeech/ASR/zapformer2/subsampling.py deleted file mode 120000 index d178adc2e5..0000000000 --- a/egs/librispeech/ASR/zapformer2/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_scaling.py b/egs/librispeech/ASR/zapformer2/test_scaling.py deleted file mode 120000 index b776da79a1..0000000000 --- a/egs/librispeech/ASR/zapformer2/test_scaling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/test_subsampling.py b/egs/librispeech/ASR/zapformer2/test_subsampling.py deleted file mode 120000 index 2925ea3c51..0000000000 --- a/egs/librispeech/ASR/zapformer2/test_subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer2/train.py b/egs/librispeech/ASR/zapformer2/train.py deleted file mode 100755 index 4294e139f6..0000000000 --- a/egs/librispeech/ASR/zapformer2/train.py +++ /dev/null @@ -1,1678 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --full-libri 1 \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --full-libri 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default) - - ctc loss - - attention decoder loss - - cr-ctc loss (should use half the max-duration compared to regular ctc) -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from attention_decoder import AttentionDecoderModel -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Sched3, TransformedAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.exp_augment import ExpAugment # using this, not lhotse's version of nn.Module -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def get_adjusted_lr_batches(params: AttributeDict) -> float: - # returns an adjusted form of the "lr_batches" parameter used to set the learning - # rate in the Sched3 scheduler. - # We want the final LR to be based on the geometric mean of "how much data we - # have seen" and "how many batches we have seen". - # an easier way to look at it is this: the formula for learning rate depends - # on (cur_batch / lr_batches). if we write this as: - # (cur_batch * (duration_ratio ** 0.5)) / params.lr_batches - # then the numerator is a geometric mean of "how many batches we have seen" - # and "how much data we have seen". We can achieve this by setting - # lr_batches = params.lr_batches * (duration_ratio ** -0.5). - duration_ratio = (params.max_duration * params.world_size) / params.ref_duration - lr_batches = params.lr_batches * (duration_ratio ** -0.5) - logging.info(f"Adjusting lr-batches {params.lr_batches} for duration_ratio={duration_ratio} to {lr_batches}") - return lr_batches - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def lookup(params: AttributeDict, name: str): - """ - Interprets numerical arguments in `params` by taking into account base-dim; - also parses comma-separated lists of integers, turning them into tuples. - If a particular attribute ending in "dim" is not present we look up - the same name but ending in "factor", and multiply the elements by base_dim. - """ - try: - attr = getattr(params, name) - try: - attr = tuple(map(int, attr.split(","))) # tuple of comma-separated ints - if len(attr) == 1: - attr = attr[0] - except: - pass # leave attr as it is, e.g. a string. - return attr - except AttributeError as e: - if name[-3:] != "dim": - raise e - try: - attr = getattr(params, name[:-3] + "multiple") - if isinstance(attr, str): - attr = tuple(map(int, attr.split(","))) # tuple of ints - base_dim = params.base_dim - attr = tuple([i * base_dim for i in attr]) - if len(attr) == 1: - attr = attr[0] - else: # assume int. - assert isinstance(attr, (int, float)), (name, attr) - attr = attr * params.base_dim - return attr - except AttributeError as e: - raise RuntimeError(f"cannot find or infer attribute {name} in params: {e}") - - - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="3,5,6,6,6,5", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--base-dim", - type=int, - default=64, - help="Dimension that, via multiples, defines the dimensions of the model." - ) - - parser.add_argument( - "--embed-multiple", - type=int, - default=6, - help="Output dimension of frontend, as multiple of base-dim; determines bypass dimensions in zipformer stacks and zipformer output dim.", - ) - - parser.add_argument( - "--feedforward-multiple", - type=str, - default="3,3,3,3,3,3", - help="Factor by which the feedforward hidden dim is greater than the encoder-dim, per stack: a single int or comma-separated list.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers, per stack: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-multiple", - type=str, - default="4,6,9,12,9,6", - help="Factor by which encoder-dim is larger then base-dim, per encoder stack.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-multiple", - type=int, - default=8, - help="Factor by which embedding dimension in the decoder model is larger than base-dim.", - ) - - parser.add_argument( - "--joiner-multiple", - type=int, - default=8, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--attention-decoder-multiple", - type=int, - default=8, - help="""Factor by which attention decoder dim is larger than base-dim""", - ) - - parser.add_argument( - "--attention-decoder-num-layers", - type=int, - default=6, - help="""Number of transformer layers used in attention decoder""", - ) - - parser.add_argument( - "--attention-decoder-attention-multiple", - type=int, - default=8, - help="""Determines attention dimension used in attention decoder""", - ) - - parser.add_argument( - "--attention-decoder-num-heads", - type=int, - default=8, - help="""Number of attention heads used in attention decoder""", - ) - - parser.add_argument( - "--attention-decoder-feedforward-multiple", - type=int, - default=4, - help="""Factor by which feedforward hidden dim in attention decoder is larger than attention-decoder-dim""" - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=True, - help="If True, use CTC head.", - ) - - parser.add_argument( - "--use-attention-decoder", - type=str2bool, - default=False, - help="If True, use attention-decoder head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--debug-interval", - type=int, - default=10, - help="""If positive, the interval at which we write various stats to the tensorboard, potentially useful for - finding parts of the network that are diverging or not well trained. - """ - ) - - parser.add_argument( - "--dump-debug-interval", - type=int, - default=0, - help="""If positive, and if debug-interval > 0 the interval at which we dump debug statistics; they - are accumulated at batches with period debug_interval. Should be at least 256 times --debug-interval. - Caution: on remotely mounted file systems this is extremely slow due to quirks of tensorboard (the file - opened, seeked-in and closed for each scalar that is written). - """ - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=17500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--aux-loss-scale", - type=float, - default=0.05, - help="Scale on auxiliary losses that are defined in the model, such " - "as cosine loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--cr-loss-scale", - type=float, - default=0.2, - help="Scale for consistency-regularization loss.", - ) - - parser.add_argument( - "--reconstruction-loss-scale", - type=float, - default=0.005, - help="Final scale for log-mel reconstruction loss (during warmup, use twice this scale).", - ) - - parser.add_argument( - "--predict-loss-scale", - type=float, - default=0.01, - help="Prediction of random k-means after widest zipformer layer" - ) - - parser.add_argument( - "--stochastic-depth-prob", - type=float, - default=0.1, - help="Probability of using a randomly chosen stack output during training, instead of " - "final output." - ) - - parser.add_argument( - "--attention-decoder-loss-scale", - type=float, - default=0.8, - help="Scale for attention-decoder loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--use-bf16", - type=str2bool, - default=False, - help="Whether to use bf16 in AMP.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - warm_step: The warmup period that dictates the decay of the - scale on pruned loss (for transducer) and the reconstruction and prediction - losses. Expressed in terms of the "adjusted batch count", i.e. the - normalized batch count after adjusting for changes in batch size. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - # parameters for attention-decoder - "ignore_id": -1, - "label_smoothing": 0.1, - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=lookup(params, "embed_dim"), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - input_dim=lookup(params, "embed_dim"), - output_downsampling_factor=2, - downsampling_factor=lookup(params, "downsampling_factor"), - num_encoder_layers=lookup(params, "num_encoder_layers"), - encoder_dim=lookup(params, "encoder_dim"), - query_head_dim=lookup(params, "query_head_dim"), - value_head_dim=lookup(params, "value_head_dim"), - num_heads=lookup(params, "num_heads"), - feedforward_multiple=lookup(params, "feedforward_multiple"), - cnn_module_kernel=lookup(params, "cnn_module_kernel"), - dropout=ScheduledFloat((0.0, 0.4), (3000.0, 0.0)), # todo: set to zero - causal=params.causal, - chunk_size=lookup(params, "chunk_size"), - left_context_frames=lookup(params, "left_context_frames"), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=lookup(params, "decoder_dim"), - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - output_downsampling_factor = 2 - joiner = Joiner( - encoder_dim=lookup(params, "embed_dim") * output_downsampling_factor, - decoder_dim=lookup(params, "decoder_dim"), - joiner_dim=lookup(params, "joiner_dim"), - vocab_size=params.vocab_size, - ) - return joiner - - -def get_attention_decoder_model(params: AttributeDict) -> nn.Module: - decoder = AttentionDecoderModel( - vocab_size=params.vocab_size, - decoder_dim=lookup(params, "attention_decoder_dim"), - num_decoder_layers=params.attention_decoder_num_layers, - attention_dim=lookup(params, "attention_decoder_attention_dim"), - num_heads=params.attention_decoder_num_heads, - feedforward_dim=params.attention_decoder_feedforward_multiple * lookup(params, "attention_decoder_attention_dim"), - memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, - sos_id=params.sos_id, - eos_id=params.eos_id, - ignore_id=params.ignore_id, - label_smoothing=params.label_smoothing, - ) - return decoder - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - if params.use_attention_decoder: - attention_decoder = get_attention_decoder_model(params) - else: - attention_decoder = None - - output_downsampling_factor = 2 - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - attention_decoder=attention_decoder, - encoder_dim=output_downsampling_factor * lookup(params, "embed_dim"), - decoder_dim=lookup(params, "decoder_dim"), - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - use_attention_decoder=params.use_attention_decoder, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - spec_augment: Optional[nn.Module] = None, - aux_loss_scale: float = 0.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - spec_augment: - The nn.Module instance (or similar object), used for training - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - - texts = batch["supervisions"]["text"] - num_copies = batch["num_copies"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - if num_copies > 1: - assert model.training - # will need the following for time-warping in nn.Module. - supervision_intervals = batch["supervisions"] - supervision_segments = torch.stack( - [ - supervision_intervals["sequence_idx"], - supervision_intervals["start_frame"], - supervision_intervals["num_frames"], - ], - dim=1, - ) # shape: (S, 3) - else: - supervision_segments = None - spec_augment = None # disable spec-aug - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, reconstruction_loss, predict_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - spec_augment=spec_augment, - supervision_segments=supervision_segments, - time_warp_factor=80, # for specaug - num_copies=num_copies, - aux_loss_scale=aux_loss_scale, - sd_prob=(params.stochastic_depth_prob if is_training else 0.0), - ) - - loss = 0.0 - - adjusted_batch_count = params.batch_idx_train - warm_step = params.warm_step - def warmup_schedule(scale, initial_factor): - # geometric warmup schedules. - warmup_factor = (1. if adjusted_batch_count >= warm_step else - initial_factor + (adjusted_batch_count / warm_step) * (1 - initial_factor)) - return scale * warmup_factor - - if params.use_transducer: - simple_loss_scale = params.simple_loss_scale - pruned_loss_scale = warmup_schedule(1.0, 0.05) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - if num_copies > 1: - loss += params.cr_loss_scale * cr_loss - - reconstruction_loss_scale = params.reconstruction_loss_scale - - loss += reconstruction_loss_scale * reconstruction_loss - - if num_copies > 1: - loss += params.predict_loss_scale * predict_loss - - if params.use_attention_decoder: - loss += params.attention_decoder_loss_scale * attention_decoder_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - nframes = (feature_lens // params.subsampling_factor).sum().item() - if num_copies > 1: - nframes = nframes * (num_copies - 1) / num_copies # omit 1st copy - info["frames"] = nframes - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if num_copies > 1: - info["cr_loss"] = cr_loss.detach().cpu().item() - if num_copies > 1: - info["predict_loss"] = predict_loss.detach().cpu().item() - info["recon_loss"] = reconstruction_loss.detach().cpu().item() - if params.use_attention_decoder: - info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - spec_augment: Optional[nn.Module] = None, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - spec_augment: - The SpecAugment or similar instance used for CR-CTC. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def get_scaler_scale(): - if params.use_autocast and scaler._scale is not None: - return scaler._scale.item() - else: - return 1.0 - - def save_bad_model(suffix: str = ""): - if params.debug_interval > 0: - optimizer.write_debug_info(summary_writer=tb_writer) - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.amp.autocast('cuda', - enabled=params.use_autocast, dtype=params.dtype - ): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - spec_augment=spec_augment, - aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: - logging.info(f"Caught exception: {e}.") - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if params.use_autocast: - cur_grad_scale = get_scaler_scale() - - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - if not params.inf_check: - register_inf_check_hooks(model) - logging.warning(f"Grad scale is small: {cur_grad_scale}") - - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or - batch_idx % 100 == 0 and cur_grad_scale < 8.0 or - batch_idx % 400 == 0 and cur_grad_scale < 32.0): - scaler.update(cur_grad_scale * 2.0) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = get_scaler_scale() - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_autocast: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - if params.batch_idx_train > 0 and params.dump_debug_interval > 0 and params.batch_idx_train % params.dump_debug_interval == 0: - optimizer.write_debug_info(summary_writer=tb_writer) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.sos_id = params.eos_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - if not params.use_attention_decoder: - params.ctc_loss_scale = 1.0 - else: - assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( - params.ctc_loss_scale, - params.attention_decoder_loss_scale, - ) - - if params.use_bf16: # amp + bf16 - assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" - assert not params.use_fp16, "You can only use either fp16 or bf16" - params.dtype = torch.bfloat16 - params.use_autocast = True - elif params.use_fp16: # amp + fp16 - params.dtype = torch.float16 - params.use_autocast = True - else: # fp32 - params.dtype = torch.float32 - params.use_autocast = False - - logging.info(f"Using dtype={params.dtype}") - logging.info(f"Use AMP={params.use_autocast}") - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - - assert params.use_ctc # for now, require CTC, we may remove this requirement later. - - spec_augment = ExpAugment() - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = TransformedAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - debug_interval=params.debug_interval, - ) - - scheduler = Sched3(optimizer, get_adjusted_lr_batches(params)) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - librispeech = LibriSpeechAsrDataModule(args) - - if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() - - # previously we used the following code to load all training cuts, - # strictly speaking, shuffled training cuts should be used instead, - # but we leave the code here to demonstrate that there is an option - # like this to combine multiple cutsets - - # train_cuts = librispeech.train_clean_100_cuts() - # train_cuts += librispeech.train_clean_360_cuts() - # train_cuts += librispeech.train_other_500_cuts() - else: - train_cuts = librispeech.train_clean_100_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librispeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics and False: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - spec_augment=spec_augment, - ) - - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - spec_augment=spec_augment, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - d = diagnostic.print_diagnostics() - filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" - torch.save(d, filename) - logging.info(f"Saved detailed diagnostics to {filename}") - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, - spec_augment: Optional[nn.Module] = None, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.amp.autocast('cuda', - enabled=params.use_autocast, dtype=params.dtype - ): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - spec_augment=spec_augment, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/zapformer2/zipformer.py b/egs/librispeech/ASR/zapformer2/zipformer.py deleted file mode 100644 index f5e1afe779..0000000000 --- a/egs/librispeech/ASR/zapformer2/zipformer.py +++ /dev/null @@ -1,2066 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union -from relative_position_attention_module_optimized import RelativePositionAttentionFunction -import torch -from encoder_interface import EncoderInterface -from scaling import ( - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - OrthogonalLinear, - SimpleOrthogonalLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - ScaleLimiter, - ActivationDropoutAndLinear, - ExpNorm, - ChunkCausalDepthwiseConv1d, - CosineSimilarityLoss, - MinProductLoss, - MaxProductLoss, - Dropout2, - FloatLike, - ScheduledFloat, - Whiten, - convert_num_channels, - limit_param_value, - penalize_abs_values_gt, - softmax, - with_loss, -) -from torch import Tensor, nn - - -class Zipformer2(EncoderInterface): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - output_downsampling_factor (int): how much to downsample at the output. Note: - we also downsample by a factor of 2 in the Conv2dSubsampling encoder. - You should probably leave this at 2. - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - - dropout (float): dropout rate - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. - """ - def __init__( - self, - input_dim: int, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - query_head_dim: Union[int, Tuple[int]] = 24, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_multiple: Union[int, Tuple[int]] = 4, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - dropout: FloatLike = None, # see code below for default - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - ) -> None: - super(Zipformer2, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - - def _to_tuple(x): - """Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - self.output_downsampling_factor = output_downsampling_factor # int - - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_multiple = _to_tuple(feedforward_multiple) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames - - # each one will be Zipformer2Encoder or OrthogonalDownsample or OrthogonalUpsample - encoders = [] - - num_encoders = len(downsampling_factor) - - # caution: some changes we made for this break the streaming, later we'll try to fix this. - encoders_downsampling_factors = [ ] - - # make it so large the limit is never reached. - max_proj_dim = max(downsampling_factor) * max(encoder_dim) - - - for i in range(num_encoders): - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim[i], - num_heads=num_heads[i], - query_head_dim=query_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_multiple=feedforward_multiple[i], - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - num_conv_modules=(2 if downsampling_factor[i] == 1 else 1), - causal=causal, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - head_dim=query_head_dim[i], - dim=downsampling_factor[i]*input_dim, - out_proj=False, # (downsampling_factor + (output_downsampling_factor,))[i+1] < downsampling_factor[i], - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. - """ - if not self.causal: - return -1, -1 - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.chunk_size) == 1, self.chunk_size - chunk_size = self.chunk_size[0] - else: - chunk_size = random.choice(self.chunk_size) - - if chunk_size == -1: - left_context_chunks = -1 - else: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.left_context_frames) == 1, self.left_context_frames - left_context_frames = self.left_context_frames[0] - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - - return chunk_size, left_context_chunks - - def forward( - self, - x: Tensor, - x_lens: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - aux_loss_scale: float = 0.0, - sd_prob: float = 0.0, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - aux_loss_scale: - If supplied, auxiliary losses such as CosineSimilarityLoss will be - applied with this scale on the loss (note, these aux losses are - reduced via summation over frames.) - sd_prob: - Stochastic-depth prob: with this probability we replace the final output - with the output of a randomly chosen stack (including the 'zero stack' which - means the original input x). Each stack except the 'zero stack' has a - separate output projection for stochastic depth, that only sees the - "non-bypass part", i.e. its encoder stack without the residual. - Returns: - Return (embeddings_lengths), where: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - chunk_size, left_context_chunks = self.get_chunk_info() - orig_seq_len = x.shape[0] - - pad = (-orig_seq_len) % max(self.downsampling_factor) - # pad sequence length to be multiple of max(self.downsampling_factor) - x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), - dim=0) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # Not support exporting a model for simulating streaming decoding - attn_mask = None - else: - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - - src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0]) - - num_stacks = len(self.downsampling_factor) - - x_sd = x - - def randomly_choose_seqs(x, this_x, prob: float): - batch_size = x.shape[1] - do_replace = (torch.rand(1, batch_size, 1, device=x.device) < prob).expand_as(x) - return torch.where(do_replace, this_x, x) - - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = downsample_by(x, ds) - T = x.shape[0] - x, this_x_sd = module( - x, - chunk_size=chunk_size, - src_key_padding_mask=( - None - if src_key_padding_mask is None - else src_key_padding_mask[..., ::ds] - ), - attn_mask=(None - if attn_mask is None - else attn_mask[::ds, ::ds] - ), - aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks) - ) - x = upsample_by(x, ds) - if sd_prob: - x_sd = randomly_choose_seqs(x_sd, upsample_by(this_x_sd, ds), 1. / (2. + i)) - - - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - od = self.output_downsampling_factor - x = downsample_by(x, od) - x = x[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - if sd_prob: - x_sd = downsample_by(x_sd, od) - x_sd = x_sd[:(orig_seq_len + od - 1) // od] # truncate so seq len not affected by padding - x = randomly_choose_seqs(x, x_sd, sd_prob) - - return x, lengths - - def _get_attn_mask( - self, x: Tensor, chunk_size: int, left_context_chunks: int - ) -> Optional[Tensor]: - """ - Return None if chunk_size == -1, else return attention mask of shape - (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True - means a masked position. - Args: - x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). - chunk_size: chunk size, must divide - """ - if chunk_size <= 0: - return None - assert all(chunk_size % d == 0 for d in self.downsampling_factor) - if left_context_chunks >= 0: - num_encoders = len(self.encoder_dim) - assert all( - chunk_size * left_context_chunks - >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders) - ) - else: - left_context_chunks = 1000000 - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - # c is chunk index for each frame, shape (seq_len,) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - c = t // chunk_size - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - c = t // chunk_size - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) - if __name__ == "__main__": - logging.info(f"attn_mask = {attn_mask}") - return attn_mask - - - def streaming_forward( - self, - x: Tensor, - x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states - """ - new_states = [] - layer_offset = 0 - - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - new_states += new_layer_states - - x = x[..., :max(self.encoder_dim)] # for historical reasons. can change this. - - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - """ - states = [] - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - embed_dim = self.encoder_dim[i] - ds = self.downsampling_factor[i] - num_heads = self.num_heads[i] - key_dim = self.query_head_dim[i] * num_heads - value_dim = self.value_head_dim[i] * num_heads - downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( - device - ) - cached_nonlin_attn = torch.zeros( - 1, batch_size, downsample_left, nonlin_attn_head_dim - ).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - return states - - -def get_max_similarity(rank: int, power: float): - """ - This returns a value for the "max_similarity" argument of CosineSimilarityLoss. - the max_similarity is an upper limit we impose on the mean value of (x_i . x_j) - if i != j are two different sequence-position indexes and x_i and x_j are - activation vectors normalized to have unit length. - - rank: the dimension of the space, usually this is the num_channels, but if - we have just up-projected from a bottleneck, it would be the bottleneck - dimension. - power: a user-tunable value strictly between 0 and 1. If we set power=1.0 it would mean - we enforce the vector dimensions to be completely independent like Gaussian noise - (don't do this); if we set power=0.0 it would be equivalent to not having - the CosineSimilarityLoss at all. - - The factor of 0.797 is sqrt(2/pi) which is the expected absolute value of a normal - variable. If x consists of independent Gaussian noise of dimension D, with - variance 1/D so that the expected 2-norm of x is 1 (so the "normalization to unit length" - would be close to a no-op for large D), then (x_i . x_j) would be distributed as - a Gaussian with variance (D / D^2 = 1/D). So the expected absolute value of (x_i . x_j) - would be sqrt(2/pi * (1/D)). By taking it to the power "power" we just get a value - between this and 1, as a kind of heuristic limit on this max_similarity. - """ - return (0.7978845608 / (rank ** 0.5)) ** power - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - -def pad_mask(mask: Optional[Tensor], seq_len: int): - # mask: (batch_size, old_seq_len) - # if mask is not None, returns mask: (batch_size, seq_len); pads with True (i.e., masked). - if mask is None: - return None - (batch_size, old_seq_len) = mask.shape - pad = seq_len - old_seq_len - if pad == 0: - return mask - else: - return torch.cat((mask, torch.ones(batch_size, pad, device=mask.device, dtype=torch.bool)), - dim=1) - - -def downsample_by(x: Tensor, downsampling_factor: int) -> Tensor: - # x: (seq_len, batch_size, num_channels) - # Returns: (seq_len // downsampling_factor, batch_size, num_channels * downsampling_factor) - (seq_len, batch_size, num_channels) = x.shape - x = x.reshape(seq_len // downsampling_factor, downsampling_factor, batch_size, num_channels) - x = x.permute(0, 2, 1, 3) - x = x.reshape(seq_len // downsampling_factor, batch_size, downsampling_factor * num_channels) - return x - -def upsample_by(x: Tensor, upsampling_factor: int) -> Tensor: - # x: (seq_len, batch_size, num_channels) - # Returns: (seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) - (seq_len, batch_size, num_channels) = x.shape - x = x.reshape(seq_len, batch_size, upsampling_factor, num_channels // upsampling_factor) - x = x.permute(0, 2, 1, 3) - x = x.reshape(seq_len * upsampling_factor, batch_size, num_channels // upsampling_factor) - return x - - -def get_dct_matrix(N): - """ - Generates an orthonormal DCT-II matrix for a given size N. - Args: - N (int): The size of the square matrix. - Returns: - torch.Tensor: The N x N orthonormal DCT-II matrix. - """ - # Create the base matrix with dimensions (N, N) - mat = torch.zeros(N, N) - # Create a tensor for the indices k (rows) and n (columns) - k = torch.arange(N).unsqueeze(1) - n = torch.arange(N).unsqueeze(0) - # Fill the matrix using the DCT-II formula - mat = math.sqrt(2 / N) * torch.cos(math.pi / (2 * N) * (2 * n + 1) * k) - # Adjust the first row (k=0) with a special normalization factor - mat[0] *= (2 ** -0.5) - return mat - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_multiple: determines the hidden dimension of the feedforward module - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module (default=31). - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - def __init__( - self, - embed_dim: int, - num_heads: int, - query_head_dim: int, - value_head_dim: int, - feedforward_multiple: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - num_conv_modules: int = 2, - causal: bool = False, - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - self.name = None # will be set from training loop - - self.residual_scale = nn.Parameter(0.5 * torch.ones(embed_dim)) - - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=embed_dim, power=0.8)) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, - num_heads=2 * num_heads, - query_head_dim=query_head_dim, - dropout=0.0, - ) - - self.self_attn1, self.self_attn2, self.self_attn3 = [ SelfAttention(embed_dim, num_heads, value_head_dim) for _ in range(3) ] - - feedforward_dim = embed_dim * feedforward_multiple - self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, dropout) - - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, dropout) - - if num_conv_modules >= 2: - self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - if num_conv_modules >= 1: - self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, causal=causal) - - self.scale_limiter = ScaleLimiter(max_var=2.0) - - self.norm = ExpNorm(embed_dim) - - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - aux_loss_scale: float = 0.0, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, head_dim) or (batch_size, 2*seq_len-1, head_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - aux_loss_scale: - If supplied, auxiliary losses such as CosineSimilarityLoss will be - applied with this scale on the loss (note, these aux losses are - reduced via summation over frames.) - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - aux_loss_scale=0.1 * aux_loss_scale, - ) - num_heads = attn_weights.shape[0] // 2 # num heads per self_attn module - attn_weights1 = attn_weights[:num_heads] - attn_weights2 = attn_weights[num_heads//2:-num_heads//2] - attn_weights3 = attn_weights[num_heads:] - - src = src + self.self_attn1(src, attn_weights1, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - - src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - - src = src + self.self_attn2(src, attn_weights2, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - - if hasattr(self, 'conv_module1'): - src = src + self.conv_module1(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) - - src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - - src = src + self.self_attn3(src, attn_weights3, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - - if hasattr(self, 'conv_module2'): - src = src + self.conv_module2(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) - - src = src + self.feed_forward3(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) - - residual_scale = limit_param_value(self.residual_scale, min=0.1, max=1.0) - offset = (src - src_orig) * residual_scale - src = src_orig + offset - - src = with_loss(src, - self.cosine_loss(offset.permute(1, 0, 2), aux_loss_scale, mask=src_key_padding_mask), - None) - - src = self.scale_limiter(src) - - src = self.norm(src) - - return src - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_nonlin_attn: Tensor, - cached_val1: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """Pass the input through the encoder layer in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or - (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - cached_val1: cached left context for the first attention module, - of shape (left_context_len, batch_size, value_dim) - cached_val2: cached left context for the second attention module, - of shape (left_context_len, batch_size, value_dim) - cached_conv1: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) - cached_conv2: cached left context for the second convolution module, - of shape (batch_size, channels, left_pad) - left_context_len: number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - x, with the same shape as src - - updated cached_key - - updated cached_nonlin_attn - - updated cached_val1 - - updated cached_val2 - - updated cached_conv1 - - updated cached_conv2 - """ - src_orig = src - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights, cached_key = self.self_attn_weights.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - left_context_len=left_context_len, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - - na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( - src, - attn_weights[0:1], - cached_x=cached_nonlin_attn, - left_context_len=left_context_len, - ) - src = src + na - - self_attn, cached_val1 = self.self_attn1.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val1, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - - self_attn, cached_val2 = self.self_attn2.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val2, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm(src) - - src = self.residual(src_orig, src) - - src = self.norm(src) - - return ( - src, - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - dim: the dimension of the input and output (layer dim may be less than this). - pos_dim: the dimension for the relative positional encoding -dropout: - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - - - """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dim: int, - head_dim: int, - out_proj: bool, - ) -> None: - super().__init__() - - # self.downsample will also reverse the downsampling operation for us afterward. - self.proj = SimpleOrthogonalLinear(dim, encoder_layer.embed_dim, bias=False) - self.proj.lr_scale = 0.75 - - self.encoder_pos = CompactRelPositionalEncoding( - head_dim, dropout_rate=0.0, length_factor=1.0 - ) - self.name = None - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - self.residual_scales = nn.Parameter( - torch.cat([ -1.0 * torch.ones(1, encoder_layer.embed_dim), - (1. / num_layers) * torch.ones(num_layers, encoder_layer.embed_dim) ], - dim=0)) - - self.copy_bypass = Identity() - - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=dim, power=0.85)) - self.offset_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=encoder_layer.embed_dim, power=0.85)) - - # make penalty_scale disappear after 20k batches; later we can try making this just a normal linear - # module. - if out_proj: - self.out_proj = SimpleOrthogonalLinear(dim, dim, bias=False) - self.out_proj.lr_scale = 0.75 - - # stochastic-depth proj. - self.sd_proj = nn.Linear(encoder_layer.embed_dim, dim) - - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - aux_loss_scale: float = 0.0, - ) -> Tuple[Tensor, Tensor]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), - but embed_dim is allowed to exceed the modules' embed_dim; we will bypass - any extra dimensions. - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - (out, out_sd), both of the same shape as src, - where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. - """ - pos_emb = self.encoder_pos(src) - - src_orig_fulldim = src - - src = self.proj(src) # project to layer dim. - - num_layers = len(self.layers) - src_orig = src - - residual_scale = limit_param_value(self.residual_scales[0], - min=-1.0, max=-0.5) - src_with_bypass = residual_scale * src - - for i, mod in enumerate(self.layers): - src = mod( - src, - pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - aux_loss_scale=aux_loss_scale/num_layers, - ) - residual_scale = limit_param_value(self.residual_scales[i + 1], - min=0.0 if i + 1 < num_layers else 0.1, - max=1.0) - src_with_bypass = src_with_bypass + residual_scale * src - - - offset = src_with_bypass - - src = src_orig_fulldim + self.proj(offset, transpose=True) - # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, - # because of some identities involving orthogonal matrices. - - if aux_loss_scale: - src = with_loss(src, - self.offset_cosine_loss(offset.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask) + - self.cosine_loss(src.permute(1, 0, 2), - aux_loss_scale, src_key_padding_mask), - None) - - src_sd = self.sd_proj(offset) - - if hasattr(self, 'out_proj'): - src = self.out_proj(src) - - return src, src_sd - - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - pos_emb = self.encoder_pos(src, left_context_len) - num_channels = src.shape[-1] - layer_dim = self.layers[0].embed_dim - if num_channels > layer_dim: - src, bypass = src[..., :layer_dim], src[..., layer_dim:] - - new_states = [] - for i, mod in enumerate(self.layers): - ( - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) = states[i * 6 : (i + 1) * 6] - ( - src, - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ) = mod.streaming_forward( - src, - pos_emb, - cached_key=cached_key, - cached_nonlin_attn=cached_nonlin_attn, - cached_val1=cached_val1, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - new_states += [ - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ] - - if num_channels > layer_dim: - src = torch.cat((src, bypass), dim=-1) - - return src, new_states - - -class ResidualModule(nn.Module): - """ - An nn.Module that implements a learnable residual scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. - """ - - def __init__( - self, - embed_dim: int, - function_scale_min: FloatLike = 0.1, - ): - super().__init__() - self.function_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.function_scale_min = copy.deepcopy(function_scale_min) - - - def _get_scales(self): - function_scale = self.function_scale - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: - function_scale = limit_param_value( - function_scale, min=float(self.function_scale_min), max=1.0, - ) - residual_scale = 1.0 - function_scale - return residual_scale, function_scale - - def forward(self, src_orig: Tensor, src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - residual_scale, function_scale = self._get_scales() - return residual_scale * src_orig + function_scale * src - - -class OrthogonalDownsample(torch.nn.Module): - """ - Downsamples on sequence axis by appending sequence-positions together, - and then optionally projects by an orthogonal matrix - - - -. Projection is initialized - in a special way and enforced to be orthogonal. - - Args: - channels: the number of input channels; the num output channels will be twice this - proj_dim: the number of channels, after combining 2 frames by interpolating their channels - as [ a b a b, .. ] that will actually be projected; the rest are just copied. - proj_dim=2 * channels would mean all channels are projected in a learned way - causal: True for causal systems, only affects error messages as requires even - input num frames. - """ - def __init__( - self, channels: int, proj_dim: int, causal: bool = False, - ): - super().__init__() - assert proj_dim <= channels * 2 - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.75 - self.causal = causal - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - - if seq_len % 2 == 1: - if torch.jit.is_tracing(): - assert ( - not self.causal - ), f"pad should be zero for exporting streaming models. Given {pad}" - src = torch.cat((src, src[-1:]), dim=0) - seq_len += 1 - - # the following will place each 2 frames of a particular channel right after - # each other as if they were two different channels. - src = torch.stack((src[0::2], src[1::2]), dim=-1) - src = src.reshape(seq_len // 2, batch_size, in_channels * 2) - proj_channels = self.proj.weight.shape[0] - if proj_channels < in_channels * 2: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) - return src - -class OrthogonalUpsample(torch.nn.Module): - """ - A very simple form of upsampling with an orthogonal matrix. - - proj_dim: the number of channels that will actually be projected; the rest are just copied. - proj_dim=channels would mean all channels are projected in a learned way - - """ - def __init__(self, channels: int, proj_dim: int): - super().__init__() - assert proj_dim <= channels - # gradually make smaller and then turn off the non-orthognality penalty. - self.proj = OrthogonalLinear(proj_dim, proj_dim, bias=False, - penalty_scale=ScheduledFloat((0.0, 20.0), (5000.0, 1.0), (10000.0, 0.1), (20000.0, 0.0))) - # lr_scale is a learning-rate factor to slow down how fast self.proj is learned. - # it will be interpreted by get_parameter_groups_with_lrs() - self.proj.lr_scale = 0.75 - - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*2), batch_size, num_channels // 2) - """ - proj_channels = self.proj.weight.shape[0] - (seq_len, batch_size, in_channels) = src.shape - - if proj_channels < in_channels: - src = torch.cat((self.proj(src[..., :proj_channels]), src[..., proj_channels:]), - dim=-1) - else: - src = self.proj(src) - - src = torch.stack((src[..., 0::2], src[..., 1::2]), - dim=1) # (seq_len, 2, batch_size, in_channels // 2) - src = src.reshape(seq_len * 2, batch_size, in_channels // 2) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the Fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embed_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - - def __init__( - self, - embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0, embed_dim - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0, length_factor - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = self.embed_dim**0.5 - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = ( - compression_length - * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) - ) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - dropout: dropout probability for attn_output_weights. Default: 0.0. - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - query_head_dim: int, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.dropout = dropout - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, - bias=True, initial_scale=0.125 * query_head_dim**-0.25 - ) - - - self.key_cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=key_head_dim, power=0.5)) - - - # the following are for diagnostics only, see --print-diagnostics option - self.copy_query = Identity() - self.copy_key = Identity() - - self.qk_max_product = MaxProductLoss(max_product=ScheduledFloat((0.0, 0.6), (20000.0, 6.0), default=5.0)) - - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - aux_loss_scale: float = 0.0, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, head_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.copy_key(k) - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - if aux_loss_scale: - k = with_loss(k, - self.key_cosine_loss(k.permute(1, 2, 0, 3).reshape(batch_size * num_heads, seq_len, query_head_dim), - aux_loss_scale / num_heads, - key_padding_mask.repeat_interleave(num_heads, dim=0) if key_padding_mask is not None else None), - None) - - - # time1 refers to target, time2 refers to source. - q = q.permute(1, 2, 0, 3) # (batch, head, time1, query_head_dim) - k = k.permute(1, 2, 0, 3) # (batch, head, time2, query_head_dim) - - if self.training: - k = with_loss(k, - self.qk_max_product(q.reshape(batch_size * num_heads, seq_len, query_head_dim), - k.reshape(batch_size * num_heads, seq_len, query_head_dim), - aux_loss_scale / num_heads), - None) - - - attn_scores = RelativePositionAttentionFunction.apply(q.contiguous(), k.contiguous(), pos_emb.repeat(num_heads, 1, 1)) - - - assert attn_scores.shape == (batch_size, num_heads, seq_len, seq_len) - attn_scores = attn_scores.permute(1, 0, 2, 3) - # (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - seq_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - - Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, ( - cached_key.shape[0], - left_context_len, - ) - k = torch.cat([cached_key, k], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - - # HERE.. not finished streaming code. - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(k_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, k_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - assert attn_scores.shape == ( - num_heads, - batch_size, - seq_len, - k_len, - ), attn_scores.shape - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - attn_weights = attn_scores.softmax(dim=-1) - - return attn_weights, cached_key - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.info( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = OrthogonalLinear(embed_dim, num_heads * value_head_dim, - bias=True, out_groups=num_heads) - - self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 - ) - - f = max(1.0, embed_dim / (num_heads * value_head_dim)) - - self.cosine_loss = CosineSimilarityLoss(max_similarity=ScheduledFloat((0.0, 0.5), (20000.0, 0.75), default=0.5)) - - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - aux_loss_scale: float = 0.0, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - src_key_padding_mask: optional Tensor of shape (batch_size, src_seq_len); only - used for the cosine similarity loss, during training. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # x: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - if aux_loss_scale: - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), - aux_loss_scale, - mask=src_key_padding_mask), None) - - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, ( - cached_val.shape[0], - left_context_len, - ) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] - - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - return x, cached_val - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model.""" - - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): - super(FeedforwardModule, self).__init__() - # try to get in the useful range of the activation function, i.e. not too small. - self.in_proj = ScaledLinear(embed_dim, feedforward_dim) - # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed - # to the TransformedAdam optimizer. - self.in_proj.weight_min_rms = 0.02 - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( - feedforward_dim, - embed_dim, - activation="SwashL", - dropout_p=dropout, - dropout_shared_dim=0, - bias=True, - initial_scale=0.5, - ) - - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(embed_dim, feedforward_dim), power=0.7)) - - - def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: - x = self.in_proj(x) - x = self.out_proj(x) - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), None) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) - - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. - Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == ( - num_heads, - batch_size, - seq_len, - left_context_len + seq_len, - ) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, ( - cached_x.shape[2], - left_context_len, - ) - x_pad = torch.cat([cached_x, x], dim=2) - # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - x = x * y - - x = self.out_proj(x) - return x, cached_x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - causal: bool, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, - 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = ( - ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) - if causal - else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - ) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, - channels, - activation="SwashR", - dropout_p=0.0, - initial_scale=0.05, - ) - self.cosine_loss = CosineSimilarityLoss(get_max_similarity(rank=min(channels, bottleneck_dim), power=0.6)) - - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - aux_loss_scale: float = 0.0, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if ( - not torch.jit.is_scripting() - and not torch.jit.is_tracing() - and chunk_size >= 0 - ): - # Not support exporting a model for simulated streaming decoding - assert ( - self.causal - ), "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - x = with_loss(x, self.cosine_loss(x.permute(1, 0, 2), aux_loss_scale, src_key_padding_mask), - None) - - return x - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. - - Args: - x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) - x = x * s - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - return x, cache - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale - - def forward(self, x): - return x * self.scale - - -def _test_zipformer_main(causal: bool = False): - seq_len = 20 - # Just make sure the forward pass runs. - - input_dim = 50 - - c = Zipformer2( - input_dim=input_dim, - encoder_dim=(64, 96), - num_heads=(4, 4), - causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,), - ) - - batch_size = 6 - seq_len = 21 - # Just make sure the forward pass runs. - f, lengths = c( - torch.randn(seq_len, batch_size, input_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - aux_loss_scale=1.0, - sd_prob=0.1, - ) - f.sum().backward() - c.eval() - x_ = c( - torch.randn(seq_len, batch_size, input_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - aux_loss_scale=1.0, - sd_prob=0.1, - ) - x_ # to remove flake8 warnings - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main(False) - _test_zipformer_main(True) From 6961d60b7e9305b9719b23378fd6a88d7365d5e0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Apr 2026 14:59:23 +0800 Subject: [PATCH 1008/1191] Remove unnecessary change to get_parameter_groups_with_lrs() from icefall/utils.py --- icefall/utils.py | 68 ++++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 46 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index d3537712a0..ca682a0326 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -648,7 +648,7 @@ def store_translations( hyp_list = [] ref_list = [] dir_ = os.path.dirname(filename) - reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) + reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) refsrc = os.path.join(dir_, "refsrc-"+str(os.path.basename(filename))) hyp = os.path.join(dir_, "hyp-"+str( os.path.basename(filename))) bleu_file = os.path.join(dir_, "bleu-"+str( os.path.basename(filename))) @@ -661,7 +661,7 @@ def store_translations( print(f"{cut_id}: ref_tgt {ref_tgt}", file=f) print(f"{cut_id}: hyp {hyp}", file=f) print("\n", file=f) - + print(f"{ref}", file=f_src) print(f"{ref_tgt}", file=f_tgt) @@ -673,7 +673,7 @@ def store_translations( with open(bleu_file, 'w') as b: print(str(bleu.corpus_score(hyp_list, [ref_list])), file=b) print(f"BLEU signiture: {str(bleu.get_signature())}", file=b) - + logging.info( f"[{bleu.corpus_score(hyp_list, [ref_list])}] " f"BLEU signiture: {str(bleu.get_signature())}" @@ -1582,16 +1582,12 @@ def get_parameter_groups_with_lrs( lr: float, include_names: bool = False, freeze_modules: List[str] = [], - attrs: List[str] = ['lr_scale', 'weight_min_rms', 'bias_min_rms', 'weight_max_rms', 'bias_max_rms', 'scale_default'], ) -> List[dict]: """ - This is to automatically create parameter-groups with overrides of parameter optimizer - settings, especially the learning rate which can be scaled using the "lr_scale" attribut - in modules, but also other possible configuration values that you may specify. - + This is for use with the ScaledAdam optimizers (more recent versions that accept lists of + named-parameters; we can, if needed, create a version without the names). - It provides a way to specify learning-rate scales and other optimizer configuration - settings inside the module, so that if + It provides a way to specify learning-rate scales inside the module, so that if any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will scale the LR of any parameters inside that module or its submodules. Note: you can set module parameters outside the __init__ function, e.g.: @@ -1611,27 +1607,20 @@ def get_parameter_groups_with_lrs( """ named_modules = list(model.named_modules()) - # flat_lr_scale[prefix] for a prefix like 'encoder.layers.3' contains - # a dict with all the optimizer configuration settings specified at this level. - # these need to be combined for all prefixes of the name of any given parameter. - flat_config = defaultdict(dict) + # flat_lr_scale just contains the lr_scale explicitly specified + # for each prefix of the name, e.g. 'encoder.layers.3', these need + # to be multiplied for all prefix of the name of any given parameter. + flat_lr_scale = defaultdict(lambda: 1.0) names = [] for name, m in model.named_modules(): names.append(name) - for attr in attrs: # we can add more here as needed - try: - # getattr(m, attr) if attr == 'lr_scale' is equivalent to m.lr_scale - flat_config[name][attr] = getattr(m, attr) - except AttributeError: - pass + if hasattr(m, "lr_scale"): + flat_lr_scale[name] = m.lr_scale - - # lr_to_parames is a dict from config-string to: + # lr_to_parames is a dict from learning rate (floating point) to: if # include_names == true, a list of (name, parameter) for that learning rate; # otherwise a list of parameters for that learning rate. - # The config-string is the repr(dict) for the dictionary of attributes combined - # over all prefixes of that parameter name. - config_to_params = defaultdict(list) + lr_to_params = defaultdict(list) for name, parameter in model.named_parameters(): split_name = name.split(".") @@ -1646,30 +1635,18 @@ def get_parameter_groups_with_lrs( if prefix in freeze_modules: logging.info(f"Remove {name} from parameters") continue - - cur_config = dict() - cur_config.update(flat_config[prefix]) # include dict items from here. + cur_lr = lr * flat_lr_scale[prefix] if prefix != "": - cur_config.update(flat_config[""]) + cur_lr *= flat_lr_scale[""] for part in split_name[1:]: prefix = ".".join([prefix, part]) - cur_config.update(flat_config[prefix]) - + cur_lr *= flat_lr_scale[prefix] + lr_to_params[cur_lr].append((name, parameter) if include_names else parameter) - config_to_params[repr(cur_config)].append((name, parameter) if include_names else parameter) - - - ans = [ ] - for config, params in config_to_params.items(): - config = eval(config) # turn from string back into dict. - try: # turn "lr_scale" into "lr" - config["lr"] = lr * config["lr_scale"] - del config["lr_scale"] - except KeyError: - pass - config["named_params" if include_names else "params"] = params - ans.append(config) - return ans + if include_names: + return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()] + else: + return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()] def optim_step_and_measure_param_change( @@ -2490,5 +2467,4 @@ def time_warp( features[sequence_idx, :num_frames], factor=time_warp_factor ) - return features From b958bf1f0c4281d9e52ae57b7edfdf8548c60df0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Apr 2026 16:26:58 +0800 Subject: [PATCH 1009/1191] Set bias_scale_limits to be the same as weight_scale_limits: (0.05,0.25) --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- egs/librispeech/ASR/zapformer/rubik.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 4fbe5cd48e..b511371bf3 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -387,7 +387,7 @@ def __init__( beta2=0.98, eps=1.0e-16, weight_scale_limits=(0.05, 0.25), - bias_scale_limits=(0.1, 0.5), + bias_scale_limits=(0.05, 0.25), scalar_scale=0.075, ): diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 1840e4c943..9b2ad4bde4 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -271,7 +271,7 @@ def __init__( beta2=0.98, eps=1.0e-16, weight_scale_limits=(0.05, 0.25), - bias_scale_limits=(0.1, 0.5), + bias_scale_limits=(0.05, 0.25), scalar_scale=0.075, ): defaults = dict( From 5c5e21b2d09d243bc44b1d005bea41420a49acda Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 4 Apr 2026 21:41:05 +0800 Subject: [PATCH 1010/1191] Change the num hours to num cuts in the weights of subsets. --- egs/librispeech/ASR/zapformer/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 80db5b928b..49fca32f7b 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1420,7 +1420,8 @@ def lr_lambda(current_step): if params.full_libri: train_cuts = librispeech.train_all_shuf_cuts() - train_cuts_len = 960.0 * 3 # 960 hours times 3 for augmentation + # train_cuts_len = 960.0 * 3 # 960 hours times 3 for augmentation + train_cuts_len = 843723 * 3 # previously we used the following code to load all training cuts, # strictly speaking, shuffled training cuts should be used instead, @@ -1432,7 +1433,7 @@ def lr_lambda(current_step): # train_cuts += librispeech.train_other_500_cuts() else: train_cuts = librispeech.train_clean_100_cuts() - train_cuts_len = 100.0 * 3 # 100 hours times 3 for speed augmentation + train_cuts_len = 85617 * 3 # 100.0 * 3 if params.use_giga or params.use_cv: if params.libri_copies > 1: @@ -1443,10 +1444,10 @@ def lr_lambda(current_step): if params.use_giga: if params.full_libri: gigaspeech_cuts = gigaspeech.train_XL_cuts() - gigaspeech_cuts_len = 10000.0 + gigaspeech_cuts_len = 8277188 # 10000.0 else: gigaspeech_cuts = gigaspeech.train_S_cuts() # e.g. for debugging - gigaspeech_cuts_len = 250.0 + gigaspeech_cuts_len = 229394 # 250.0 datasets_and_weights.append((gigaspeech_cuts, gigaspeech_cuts_len)) if params.use_cv: @@ -1455,7 +1456,7 @@ def normalize_text(c): c.supervisions[0].text = re.sub(r'[^\w\s]', '', c.supervisions[0].text).upper() return c commonvoice_cuts = commonvoice.train_cuts().map(normalize_text) - commonvoice_cuts_len = 2600.0 + commonvoice_cuts_len = 1822817 # 2600.0 datasets_and_weights.append((commonvoice_cuts, commonvoice_cuts_len)) cuts, weights = zip(*datasets_and_weights) From af080b9dbd8652674ecdc5b33ac237ae16e2fd63 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Apr 2026 16:04:34 +0800 Subject: [PATCH 1011/1191] Bug fix regarding length of libri cuts --- egs/librispeech/ASR/zapformer/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 49fca32f7b..9a8233b307 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1421,7 +1421,7 @@ def lr_lambda(current_step): if params.full_libri: train_cuts = librispeech.train_all_shuf_cuts() # train_cuts_len = 960.0 * 3 # 960 hours times 3 for augmentation - train_cuts_len = 843723 * 3 + train_cuts_len = 843723 # includes 3x speed perturbation # previously we used the following code to load all training cuts, # strictly speaking, shuffled training cuts should be used instead, @@ -1433,7 +1433,7 @@ def lr_lambda(current_step): # train_cuts += librispeech.train_other_500_cuts() else: train_cuts = librispeech.train_clean_100_cuts() - train_cuts_len = 85617 * 3 # 100.0 * 3 + train_cuts_len = 85617 # includes 3x speed perturbation if params.use_giga or params.use_cv: if params.libri_copies > 1: From 7a8d0dbd68abdd2a4616d76baa1898df65816829 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Apr 2026 15:41:03 +0800 Subject: [PATCH 1012/1191] Increase central num layers from 12 to 14. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 80db5b928b..8da2f135da 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -181,7 +181,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="6,8,12,8", + default="6,8,14,8", help="Number of zapformer encoder layers per stack, comma separated.", ) From a65c190ac9c0befc7c8bf26280c629cd680e0bea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Apr 2026 17:22:32 +0800 Subject: [PATCH 1013/1191] Add min_factor = 0.05 to InterpCosineLRScheduler, applied via interpolation --- egs/librispeech/ASR/zapformer/combined_scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index cd6a2822ad..5962a4ae0a 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -181,6 +181,7 @@ def get_lr(self): class InterpCosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, + min_factor: float = 0.05, **kwargs): """ This cosine LR scheduler is halfway between the conventional cosine LR scheduler @@ -188,6 +189,7 @@ def __init__(self, It inherits from CombinedLRScheduler (see its documentation to understand general aspects of usage). """ + self.min_factor = min_factor super().__init__(*args, **kwargs) def get_lr(self): @@ -196,6 +198,7 @@ def get_lr(self): # factor**2 would be the conventional cosine LR scheduler with cosine from 0 to pi, we interpolate # between the two. factor = 0.5 * (factor + factor ** 2) + factor = self.min_factor + factor * (1. - self.min_factor) return [x * factor for x in self.base_lrs] From b1d77fbeb5ca7f399ad245ade608b4a2ad0e0c98 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Apr 2026 17:30:26 +0800 Subject: [PATCH 1014/1191] Remove unnecessarily num_copies-related code and comments. --- egs/librispeech/ASR/zapformer/train.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index a0f26ed321..23179b60b3 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -962,21 +962,13 @@ def compute_loss( batch_idx_train = params.batch_idx_train texts = batch["supervisions"]["text"] - num_copies = batch["num_copies"] y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) if is_training: - # the num_copies thing is actually not very important any more, you can remove - # the assertion if it's a problem in future. (previously we used losses that - # required the different copies to be in sync on the time dimension, e.g. - # to use the same time warping; we don't do this any more.) - #assert num_copies == 2 batch_size = features.shape[0] features = augmentation(features, feature_lens) - else: - assert num_copies == 1 with torch.set_grad_enabled(is_training): From 92bb2852f1327456611d01e3d626796c77fce7ef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 11:35:51 +0800 Subject: [PATCH 1015/1191] Initialize out_proj scales of submodules to zero. --- egs/librispeech/ASR/zapformer/zapformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index a5ad2df8a8..1eed428ad4 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -938,7 +938,7 @@ def __init__( # out proj for the value times gating. self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.5 + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.0 ) self.weighted_mean = WeightedMean(num_heads * value_head_dim, causal) # TODO: fix causal option @@ -1252,7 +1252,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int): feedforward_dim, embed_dim, activation="SwashR", - initial_scale=0.5, + initial_scale=0.0, bias=True, ) @@ -1820,7 +1820,7 @@ def __init__( bottleneck_dim, channels, activation="SwashR", - initial_scale=0.05, + initial_scale=0.0, ) def forward( From d8614fc2b5d691fedf0a522eade2e8b141782c3b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 11:47:17 +0800 Subject: [PATCH 1016/1191] Revert zero-out-proj-scales and instead decrease out_proj initial scale of self-attention from 0.5 to 0.1. --- egs/librispeech/ASR/zapformer/zapformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 1eed428ad4..f0284c36b6 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -938,7 +938,7 @@ def __init__( # out proj for the value times gating. self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.0 + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.1 ) self.weighted_mean = WeightedMean(num_heads * value_head_dim, causal) # TODO: fix causal option @@ -1252,7 +1252,7 @@ def __init__(self, embed_dim: int, feedforward_dim: int): feedforward_dim, embed_dim, activation="SwashR", - initial_scale=0.0, + initial_scale=0.5, bias=True, ) @@ -1820,7 +1820,7 @@ def __init__( bottleneck_dim, channels, activation="SwashR", - initial_scale=0.0, + initial_scale=0.05, ) def forward( From e9b20790bb96388d5359c9ca82fa14477b4d903b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 14:43:56 +0800 Subject: [PATCH 1017/1191] Make num_copies rise linearly with epoch, starting from 1, to --max-copies --- .../ASR/zapformer/asr_datamodule.py | 20 +++++----------- egs/librispeech/ASR/zapformer/train.py | 23 ++++++++++++++++--- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 29842c72fc..2e5b2b0cae 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -118,7 +118,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): type=float, default=800.0, help="Maximum pooled recordings duration (seconds) in a " - "single batch, including the --num-copies argument, so if --num-copies " + "single batch, including multiple copies, so if num_copies " "is larger the actual duration prior to making copies will be smaller." ) group.add_argument( @@ -210,15 +210,6 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="AudioSamples or PrecomputedFeatures", ) - group.add_argument( - "--num-copies", - type=int, - default=4, - help="The number of copies of each training example selected in each batch (they will be augmented " - "differently). If you make num-copies larger there will be more steps per epoch so you should probably make " - "num-epochs smaller. " - ) - parser.add_argument( "--libri-copies", type=int, @@ -245,6 +236,7 @@ def train_dataloaders( self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, + num_copies: int = 1, ) -> DataLoader: """ Args: @@ -280,7 +272,7 @@ def train_dataloaders( logging.info("About to create train dataset") train = MulticopyDataset( - num_copies=self.args.num_copies, + num_copies=num_copies, input_strategy=eval(self.args.input_strategy)(), cut_transforms=transforms, input_transforms=[], @@ -299,7 +291,7 @@ def train_dataloaders( # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. train = MulticopyDataset( - num_copies=self.args.num_copies, + num_copies=num_copies, cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, @@ -310,7 +302,7 @@ def train_dataloaders( logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( cuts_train, - max_duration=self.args.max_duration / self.args.num_copies, + max_duration=self.args.max_duration / num_copies, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, buffer_size=self.args.num_buckets * 2000, @@ -321,7 +313,7 @@ def train_dataloaders( logging.info("Using SimpleCutSampler.") train_sampler = SimpleCutSampler( cuts_train, - max_duration=self.args.max_duration / self.args.num_copies, + max_duration=self.args.max_duration / num_copies, shuffle=self.args.shuffle, ) logging.info("About to create train dataloader") diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 23179b60b3..c526f8d4a1 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -386,6 +386,13 @@ def get_parser(): help="Number of epochs to train.", ) + parser.add_argument( + "--max-copies", + type=int, + default=8, + help="The num_copies to use in the dataloader on the last epoch (it rises linearly)" + ) + parser.add_argument( "--batches-per-epoch", type=int, @@ -1500,15 +1507,16 @@ def remove_short_and_long_utt(c: Cut): else: sampler_state_dict = None - train_dl = asr_datamodule.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) if not params.print_diagnostics and False: + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + num_copies=1, + ) scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, @@ -1524,6 +1532,15 @@ def remove_short_and_long_utt(c: Cut): for epoch in range(params.start_epoch, params.num_epochs + 1): fix_random_seed(params.seed + epoch - 1) + + num_copies = 1 + round((params.max_copies - 1) * epoch / params.num_epochs) + logging.info("On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + num_copies=num_copies, + ) + sampler_state_dict=None train_dl.sampler.set_epoch(epoch - 1) if tb_writer is not None: From 9e3c29a14afa2f512d4504404bef3a0e7ee55100 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 15:35:29 +0800 Subject: [PATCH 1018/1191] Implement more accurate LR schedule with variable_combined_scheduler --- egs/librispeech/ASR/zapformer/train.py | 23 +-- .../zapformer/variable_combined_scheduler.py | 163 ++++++++++++++++++ 2 files changed, 176 insertions(+), 10 deletions(-) create mode 100644 egs/librispeech/ASR/zapformer/variable_combined_scheduler.py diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index c526f8d4a1..1a72ed29df 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -83,9 +83,9 @@ pass -from combined_scheduler import CombinedLRScheduler +from variable_combined_scheduler import VariableCombinedLRScheduler try: - from combined_scheduler import InterpCosineLRScheduler + from variable_combined_scheduler import InterpCosineLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR @@ -396,9 +396,10 @@ def get_parser(): parser.add_argument( "--batches-per-epoch", type=int, - default=4550, + default=2200, help="Assumed number of batches per epoch for purposes of setting learning rate; only " - "makes a difference during the first batch, after which an observed value is used.." + "makes a difference during the first batch, after which an observed value is used. This " + "is the num batches where num_copies==1, i.e. on the first epoch" ) @@ -787,7 +788,7 @@ def load_checkpoint_if_available( model: nn.Module, model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[CombinedLRScheduler] = None, + scheduler: Optional[VariableCombinedLRScheduler] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -853,7 +854,7 @@ def save_checkpoint( model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[CombinedLRScheduler] = None, + scheduler: Optional[VariableCombinedLRScheduler] = None, sampler: Optional[CutSampler] = None, scaler: Optional[GradScaler] = None, rank: int = 0, @@ -1069,7 +1070,7 @@ def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, - scheduler: CombinedLRScheduler, + scheduler: VariableCombinedLRScheduler, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -1387,9 +1388,11 @@ def lr_lambda(current_step): progress = current_step / total_steps return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + + def get_num_copies(epoch): + return 1 + round((params.max_copies - 1) * epoch / params.num_epochs) scheduler = InterpCosineLRScheduler(optimizer, - batches_per_epoch=params.batches_per_epoch, - num_epochs=params.num_epochs) + batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)]) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1533,7 +1536,7 @@ def remove_short_and_long_utt(c: Cut): for epoch in range(params.start_epoch, params.num_epochs + 1): fix_random_seed(params.seed + epoch - 1) - num_copies = 1 + round((params.max_copies - 1) * epoch / params.num_epochs) + num_copies = get_num_copies(epoch) logging.info("On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") train_dl = asr_datamodule.train_dataloaders( train_cuts, diff --git a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py new file mode 100644 index 0000000000..5be7217d62 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py @@ -0,0 +1,163 @@ +import torch +from torch import Tensor +from torch.optim import Optimizer +from typing import List +import math +import logging + +class VariableCombinedLRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch; in this version the expected number of batches can be different + for different epochs. + + + base_batches = 3100 + multiples = [ 1, 1, 1, 2, 2, 2, 3, 3, 3 ] + batches_per_epoch = [ m * base_batches for m in multiples ] + + scheduler = InterpCosineLRScheduler(optimizer, batches_per_epoch=batches_per_epoch) + for epoch in range(len(multiples)): + scheduler.set_epoch(epoch+1) # caution: one-based epoch count + train_dl = f(multiples[epoch]) # num batches propto multiples. + for batch_idx, batch in enumerate(train_dl): # train_dl expected + scheduler.set_batch_idx(batch_idx) + + Args: + optimizer: optimizer that we will set the learning rates in; the initial learning rate(s) in + the optimizer is/are the base LRs and we set the LR as a fraction of those. + batches_per_epoch: the estimated number of batches per epoch; use your best guess. + num_epochs: the total number of epochs you will train for + """ + def __init__(self, + optimizer: Optimizer, + batches_per_epoch: List[int], + verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.batches_per_epoch = list(batches_per_epoch) # copy the list in case it's modified + self.tot_batches = sum(self.batches_per_epoch) + self.adjust_factor = 1.0 + + self.epoch = -1 + self.batch = -1 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + # the user might try to override the base_lr, so don't include this in the state. + # previously they were included. + # "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + "adjust_factor": self.adjust_factor, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def set_batch(self, batch: int): + """ Sets the batch index within the epoch, with zero-based counting (not that this matters much).""" + # set the within-epoch batch index. + self.batch = batch + self._set_lrs() + + def set_epoch(self, epoch: int): + """ Sets the epoch with one-based counting, so the first epoch is 1; the epoch should not exceed the num_epochs used + in the constructor. """ + assert epoch > 0 and epoch <= len(self.batches_per_epoch) # Epoch numbers are assumed to be be 1-based indexes. + if epoch == self.epoch + 1 and self.batch > 0 and self.epoch > 0: + self.adjust_factor = self.batch / self.batches_per_epoch[self.epoch-1] + logging.info(f"Setting self.adjust_factor = {adjust_factor} = observed/expected batches {self.batch}/{self.batches_per_epoch[self.epoch-1]} on epoch {self.epoch}") + + self.epoch = epoch + self.past_batches = sum(self.batches_per_epoch[:epoch-1], start=0) + self._set_lrs() + + def get_progress(self): + if self.epoch <= 0: + return 0.0 + else: + # epoch indexes start from 1 so we have to subtract 1 before indexing self.batches_per_epoch + past_batches = self.past_batches # sum of batches on previous eopchs + tot_batches = self.tot_batches # anticipated total batches + cur_max_batches = self.batches_per_epoch[self.epoch - 1] + cur_batches = min(cur_max_batches, self.adjust_factor * self.batch) + + progress = (past_batches + cur_batches) / tot_batches + assert progress <= 1.0 + return progress + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.warning( + f"Epoch={self.epoch}, batch={self.batch}, num_epochs={self.num_epochs}, batches_per_epoch={self.batches_per_epoch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + + + +class InterpCosineLRScheduler(VariableCombinedLRScheduler): + def __init__(self, + *args, + min_factor: float = 0.05, + **kwargs): + """ + This cosine LR scheduler is halfway between the conventional cosine LR scheduler + that takes the cosine from 0 to pi, and one that takes the cosine from 0 to pi/2. + It inherits from CombinedLRScheduler (see its documentation + to understand general aspects of usage). + """ + self.min_factor = min_factor + super().__init__(*args, **kwargs) + + def get_lr(self): + progress = self.get_progress() + factor = math.cos((math.pi / 2) * progress) + # factor**2 would be the conventional cosine LR scheduler with cosine from 0 to pi, we interpolate + # between the two. + factor = 0.5 * (factor + factor ** 2) + factor = self.min_factor + factor * (1. - self.min_factor) + return [x * factor for x in self.base_lrs] From 0e865c4f4c5e16214855a79d1a6143b12324974c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 16:17:07 +0800 Subject: [PATCH 1019/1191] Fix schedule of num_copies, start from 1 not 2. --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 4 ++-- egs/librispeech/ASR/zapformer/train.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 2e5b2b0cae..7e91edac5d 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -299,7 +299,7 @@ def train_dataloaders( ) if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") + logging.info(f"Using DynamicBucketingSampler, num_copies={num_copies}") train_sampler = DynamicBucketingSampler( cuts_train, max_duration=self.args.max_duration / num_copies, @@ -310,7 +310,7 @@ def train_dataloaders( drop_last=self.args.drop_last, ) else: - logging.info("Using SimpleCutSampler.") + logging.info(f"Using SimpleCutSampler, num_copies={num_copies}") train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration / num_copies, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 1a72ed29df..b4ff98091a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1390,7 +1390,8 @@ def lr_lambda(current_step): def get_num_copies(epoch): - return 1 + round((params.max_copies - 1) * epoch / params.num_epochs) + # num_epochs arg is one-based. + return max(1, params.max_copies * epoch / params.num_epochs) scheduler = InterpCosineLRScheduler(optimizer, batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)]) From 4758c6863f6d28b4d7cefdf071646b09d4511188 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 17:26:24 +0800 Subject: [PATCH 1020/1191] Bug fix RE adjust_factor --- egs/librispeech/ASR/zapformer/variable_combined_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py index 5be7217d62..818b429b44 100644 --- a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py @@ -98,7 +98,7 @@ def set_epoch(self, epoch: int): assert epoch > 0 and epoch <= len(self.batches_per_epoch) # Epoch numbers are assumed to be be 1-based indexes. if epoch == self.epoch + 1 and self.batch > 0 and self.epoch > 0: self.adjust_factor = self.batch / self.batches_per_epoch[self.epoch-1] - logging.info(f"Setting self.adjust_factor = {adjust_factor} = observed/expected batches {self.batch}/{self.batches_per_epoch[self.epoch-1]} on epoch {self.epoch}") + logging.info(f"Setting self.adjust_factor = {self.adjust_factor} = observed/expected batches {self.batch}/{self.batches_per_epoch[self.epoch-1]} on epoch {self.epoch}") self.epoch = epoch self.past_batches = sum(self.batches_per_epoch[:epoch-1], start=0) From 53ad727987ef68d0cc88c12c13ea1a3d6f32fda6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 18:34:24 +0800 Subject: [PATCH 1021/1191] Invert adjust_factor --- egs/librispeech/ASR/zapformer/variable_combined_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py index 818b429b44..d825ab7d1c 100644 --- a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py @@ -97,8 +97,8 @@ def set_epoch(self, epoch: int): in the constructor. """ assert epoch > 0 and epoch <= len(self.batches_per_epoch) # Epoch numbers are assumed to be be 1-based indexes. if epoch == self.epoch + 1 and self.batch > 0 and self.epoch > 0: - self.adjust_factor = self.batch / self.batches_per_epoch[self.epoch-1] - logging.info(f"Setting self.adjust_factor = {self.adjust_factor} = observed/expected batches {self.batch}/{self.batches_per_epoch[self.epoch-1]} on epoch {self.epoch}") + self.adjust_factor = self.batches_per_epoch[self.epoch-1] / self.batch + logging.info(f"Setting self.adjust_factor = {self.adjust_factor} = expected/observed batches {self.batches_per_epoch[self.epoch-1]}/{self.batch} on epoch {self.epoch}") self.epoch = epoch self.past_batches = sum(self.batches_per_epoch[:epoch-1], start=0) From 7856b73a459b1feb6744d83c919eb75bbc4da355 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 19:11:12 +0800 Subject: [PATCH 1022/1191] Fix f-string --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b4ff98091a..7c06b05dcd 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1538,7 +1538,7 @@ def remove_short_and_long_utt(c: Cut): fix_random_seed(params.seed + epoch - 1) num_copies = get_num_copies(epoch) - logging.info("On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") + logging.info(f"On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") train_dl = asr_datamodule.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict, From fab7cde7633be617722670d5e8fd90a0271d77f0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 20:07:34 +0800 Subject: [PATCH 1023/1191] Add torch.distributed.barrier() around anything that might call fix_random_seed() --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 5 ++++- egs/librispeech/ASR/zapformer/train.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 7e91edac5d..315d88faee 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -327,6 +327,9 @@ def train_dataloaders( seed = torch.randint(0, 100000, ()).item() worker_init_fn = _SeedWorkers(seed) + # need torch.distributed.barrier() before and after anything that might call lhotse.fix_random_seed() as it fixes random seeds of all GPUs, + # not just the GPU of this process. + torch.distributed.barrier() train_dl = DataLoader( train, sampler=train_sampler, @@ -335,7 +338,7 @@ def train_dataloaders( persistent_workers=False, worker_init_fn=worker_init_fn, ) - + torch.distributed.barrier() return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7c06b05dcd..22cec60709 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1301,6 +1301,9 @@ def run(rank, world_size, args): fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) + # need torch.distributed.barrier() after fix_random_seed() as it fixes + # random seeds of all GPUs, not just the GPU of this process. + torch.distributed.barrier() setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") @@ -1391,7 +1394,7 @@ def lr_lambda(current_step): def get_num_copies(epoch): # num_epochs arg is one-based. - return max(1, params.max_copies * epoch / params.num_epochs) + return max(1, int(params.max_copies * epoch / params.num_epochs)) scheduler = InterpCosineLRScheduler(optimizer, batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)]) @@ -1536,6 +1539,7 @@ def remove_short_and_long_utt(c: Cut): for epoch in range(params.start_epoch, params.num_epochs + 1): fix_random_seed(params.seed + epoch - 1) + torch.distributed.barrier() num_copies = get_num_copies(epoch) logging.info(f"On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") From 3c4b5bf26147cd54627ef38dc8a5eda2be47b614 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 6 Apr 2026 20:25:57 +0800 Subject: [PATCH 1024/1191] Re-set random seed after creating dataloader --- egs/librispeech/ASR/zapformer/train.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 22cec60709..5228c25c9c 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1538,6 +1538,8 @@ def remove_short_and_long_utt(c: Cut): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): + # fix the random seed before + torch.distributed.barrier() fix_random_seed(params.seed + epoch - 1) torch.distributed.barrier() @@ -1550,6 +1552,11 @@ def remove_short_and_long_utt(c: Cut): ) sampler_state_dict=None train_dl.sampler.set_epoch(epoch - 1) + # Re-do fixing the random seed because I believe in asr_datamodule.train_dataloaders(), fix_random_seed() + # may get called from an arbitrary worker and affect the seed of *all* the GPUs. + torch.distributed.barrier() + fix_random_seed(params.seed + epoch - 1) + torch.distributed.barrier() if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) From b91c25593dfca5d7cc2a40b6951f08ca1ca48051 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 7 Apr 2026 10:22:23 +0800 Subject: [PATCH 1025/1191] fix streaming decoding --- .../ASR/zapformer/decode_stream.py | 147 ++++++++++++++ .../ASR/zapformer/streaming_decode.py | 59 ++++-- egs/librispeech/ASR/zapformer/zapformer.py | 179 ++++++++++++++---- 3 files changed, 338 insertions(+), 47 deletions(-) create mode 100644 egs/librispeech/ASR/zapformer/decode_stream.py diff --git a/egs/librispeech/ASR/zapformer/decode_stream.py b/egs/librispeech/ASR/zapformer/decode_stream.py new file mode 100644 index 0000000000..a1bf671bf5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decode_stream.py @@ -0,0 +1,147 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None + assert device == decoding_graph.device + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, at encoder output + self.done_frames: int = 0 + + # The encoder_embed subsample features (T - 7) // 2 + self.pad_length = 7 + + if params.decoding_method == "greedy_search": + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[-1] * (params.context_size - 1) + [params.blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.params.decoding_method == "greedy_search": + return self.hyp[self.params.context_size :] # noqa + elif self.params.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.params.context_size :] # noqa + else: + assert self.params.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py index 400f7804ce..c1437297f3 100755 --- a/egs/librispeech/ASR/zapformer/streaming_decode.py +++ b/egs/librispeech/ASR/zapformer/streaming_decode.py @@ -228,8 +228,9 @@ def get_init_states( device: torch.device = torch.device("cpu"), ) -> List[torch.Tensor]: """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*5:(i+1)*5] - is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len). + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*9:(i+1)*9] + is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len, + cached_attn_wm_sum, cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). states[-2] is the cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) states[-1] is processed_lens of shape (batch,), which records the number @@ -256,8 +257,9 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: Each element in state_list corresponding to the internal state of the zapformer model for a single utterance. For element-n, state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, - cached_norm_stats, cached_norm_len). + state_list[n][i*9:(i+1)*9] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len, cached_attn_wm_sum, + cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). state_list[n][-2] is the cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) state_list[n][-1] is processed_lens of shape (batch,), which records the number @@ -267,12 +269,12 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: It is the inverse of :func:`unstack_states`. """ batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 5 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 5 + assert (len(state_list[0]) - 2) % 9 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 9 batch_states = [] for layer in range(tot_num_layers): - layer_offset = layer * 5 + layer_offset = layer * 9 # cached_key: (left_context_len, batch_size, key_dim) cached_key = torch.cat( [state_list[i][layer_offset] for i in range(batch_size)], dim=1 @@ -293,12 +295,32 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: cached_norm_len = torch.cat( [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 ) + # cached_attn_wm_sum: (1, batch_size, channels) + cached_attn_wm_sum = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=1 + ) + # cached_attn_wm_num_frames: (batch_size,) + cached_attn_wm_num_frames = torch.cat( + [state_list[i][layer_offset + 6] for i in range(batch_size)], dim=0 + ) + # cached_conv_wm_sum: (1, batch_size, channels) + cached_conv_wm_sum = torch.cat( + [state_list[i][layer_offset + 7] for i in range(batch_size)], dim=1 + ) + # cached_conv_wm_num_frames: (batch_size,) + cached_conv_wm_num_frames = torch.cat( + [state_list[i][layer_offset + 8] for i in range(batch_size)], dim=0 + ) batch_states += [ cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len, + cached_attn_wm_sum, + cached_attn_wm_num_frames, + cached_conv_wm_sum, + cached_conv_wm_num_frames, ] cached_embed_left_pad = torch.cat( @@ -324,8 +346,8 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: state_list: A list of list. Each element in state_list corresponding to the internal state of the zapformer model for a single utterance. """ - assert (len(batch_states) - 2) % 5 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 5 + assert (len(batch_states) - 2) % 9 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 9 processed_lens = batch_states[-1] batch_size = processed_lens.shape[0] @@ -333,16 +355,25 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: state_list = [[] for _ in range(batch_size)] for layer in range(tot_num_layers): - layer_offset = layer * 5 + layer_offset = layer * 9 # chunk dim=1 for attention maps cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) cached_value_list = batch_states[layer_offset + 1].chunk(chunks=batch_size, dim=1) - + # chunk dim=0 for conv and norm stats cached_conv_list = batch_states[layer_offset + 2].chunk(chunks=batch_size, dim=0) cached_norm_stats_list = batch_states[layer_offset + 3].chunk(chunks=batch_size, dim=0) cached_norm_len_list = batch_states[layer_offset + 4].chunk(chunks=batch_size, dim=0) - + + # chunk dim=1 for attn wm sum + cached_attn_wm_sum_list = batch_states[layer_offset + 5].chunk(chunks=batch_size, dim=1) + # chunk dim=0 for attn wm num frames + cached_attn_wm_num_frames_list = batch_states[layer_offset + 6].chunk(chunks=batch_size, dim=0) + # chunk dim=1 for conv wm sum + cached_conv_wm_sum_list = batch_states[layer_offset + 7].chunk(chunks=batch_size, dim=1) + # chunk dim=0 for conv wm num frames + cached_conv_wm_num_frames_list = batch_states[layer_offset + 8].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): state_list[i] += [ cached_key_list[i], @@ -350,6 +381,10 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv_list[i], cached_norm_stats_list[i], cached_norm_len_list[i], + cached_attn_wm_sum_list[i], + cached_attn_wm_num_frames_list[i], + cached_conv_wm_sum_list[i], + cached_conv_wm_num_frames_list[i], ] cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index a5ad2df8a8..dc3f9206ef 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -337,8 +337,9 @@ def streaming_forward( A tensor of shape (batch_size,) containing the number of frames in `x` before padding. caches: list of cached tensors of all encoder layers. For layer-i, - caches[i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, - cached_norm_stats, cached_norm_len). + caches[i*9:(i+1)*9] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len, cached_attn_wm_sum, + cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). src_key_padding_mask: The mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. @@ -374,7 +375,7 @@ def streaming_forward( x = downsample_by(x, ds) # Slice out the specific caches for the current module - module_caches = caches[layer_offset * 5 : (layer_offset + num_layers) * 5] + module_caches = caches[layer_offset * 9 : (layer_offset + num_layers) * 9] x, new_module_caches = module.streaming_forward( src=x, @@ -413,8 +414,9 @@ def get_init_caches( ) -> List[Tensor]: """Get initial caches. - A list of cached tensors of all encoder layers. For layer-i, caches[i*5:(i+1)*5] - is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len). + A list of cached tensors of all encoder layers. For layer-i, caches[i*9:(i+1)*9] + is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len, + cached_attn_wm_sum, cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). """ caches = [] for i, module in enumerate(self.encoders): @@ -435,12 +437,22 @@ def get_init_caches( cached_norm_stats = cached_norm_stats.to(device) cached_norm_len = cached_norm_len.to(device) + attn_value_dim = self.value_head_dim[i] * num_heads + cached_attn_wm_sum = torch.zeros(1, batch_size, attn_value_dim, device=device) + cached_attn_wm_num_frames = torch.zeros(batch_size, dtype=torch.int64, device=device) + cached_conv_wm_sum = torch.zeros(1, batch_size, embed_dim, device=device) + cached_conv_wm_num_frames = torch.zeros(batch_size, dtype=torch.int64, device=device) + caches.extend([ cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len, + cached_attn_wm_sum, + cached_attn_wm_num_frames, + cached_conv_wm_sum, + cached_conv_wm_num_frames, ]) return caches @@ -619,9 +631,13 @@ def streaming_forward( cached_conv: Tensor, cached_norm_stats: Tensor, cached_norm_len: Tensor, + cached_attn_wm_sum: Tensor, + cached_attn_wm_num_frames: Tensor, + cached_conv_wm_sum: Tensor, + cached_conv_wm_num_frames: Tensor, left_context_len: int, src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Pass the input through the encoder layer in streaming forward mode. Args: @@ -631,6 +647,10 @@ def streaming_forward( cached_conv: cached left context for the convolution module, of shape (batch_size, channels, left_pad) cached_norm_stats: cached SequenceNorm stats, of shape (batch_size,) cached_norm_len: cached SequenceNorm length, scalar. + cached_attn_wm_sum: (1, batch, channels), cumulative sum for attention weighted_mean + cached_attn_wm_num_frames: (batch,), number of frames for attention weighted_mean + cached_conv_wm_sum: (1, batch, channels), cumulative sum for conv weighted_mean + cached_conv_wm_num_frames: (batch,), number of frames for conv weighted_mean left_context_len: number of left context frames. src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len + seq_len); True means masked position. May be None. @@ -642,6 +662,10 @@ def streaming_forward( - updated cached_conv - updated cached_norm_stats - updated cached_norm_len + - updated cached_attn_wm_sum + - updated cached_attn_wm_num_frames + - updated cached_conv_wm_sum + - updated cached_conv_wm_num_frames """ src_orig = src @@ -652,19 +676,23 @@ def streaming_forward( src = src + self.feed_forward1(src, src_key_padding_mask=chunk_mask) # may try changing src_pre_ff1 to src or vice versa. - self_attn_out, cached_key, cached_value = self.self_attn.streaming_forward( + self_attn_out, cached_key, cached_value, cached_attn_wm_sum, cached_attn_wm_num_frames = self.self_attn.streaming_forward( x_qkp=src_pre_ff1, x_vg=src, left_context_len=left_context_len, cached_key=cached_key, cached_value=cached_value, + cached_wm_sum=cached_attn_wm_sum, + cached_wm_num_frames=cached_attn_wm_num_frames, key_padding_mask=src_key_padding_mask, ) src = src + self_attn_out - src_conv, cached_conv = self.conv_module.streaming_forward( + src_conv, cached_conv, cached_conv_wm_sum, cached_conv_wm_num_frames = self.conv_module.streaming_forward( 3.0 * src, - cache=cached_conv, + cached_conv=cached_conv, + cached_wm_sum=cached_conv_wm_sum, + cached_wm_num_frames=cached_conv_wm_num_frames, src_key_padding_mask=chunk_mask, ) src = src + src_conv @@ -689,6 +717,10 @@ def streaming_forward( cached_conv, cached_norm_stats, cached_norm_len, + cached_attn_wm_sum, + cached_attn_wm_num_frames, + cached_conv_wm_sum, + cached_conv_wm_num_frames, ) @@ -808,8 +840,9 @@ def streaming_forward( Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). caches: list of cached tensors of N encoder layers. For layer-i, - caches[i*5:(i+1)*5] is (cached_key, cached_value, cached_conv, - cached_norm_stats, cached_norm_len). + caches[i*9:(i+1)*9] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len, cached_attn_wm_sum, + cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). left_context_len: Number of left context frames. src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len + seq_len); True means masked position. @@ -825,7 +858,7 @@ def streaming_forward( src = self.proj(src) num_layers = len(self.layers) - assert len(caches) == num_layers * 5 + assert len(caches) == num_layers * 9 residual_scale = self.residual_scales[0] input_scale = self.input_scale @@ -841,7 +874,11 @@ def streaming_forward( cached_conv, cached_norm_stats, cached_norm_len, - ) = caches[i * 5 : (i + 1) * 5] + cached_attn_wm_sum, + cached_attn_wm_num_frames, + cached_conv_wm_sum, + cached_conv_wm_num_frames, + ) = caches[i * 9 : (i + 1) * 9] ( src, @@ -850,6 +887,10 @@ def streaming_forward( new_cached_conv, new_cached_norm_stats, new_cached_norm_len, + new_cached_attn_wm_sum, + new_cached_attn_wm_num_frames, + new_cached_conv_wm_sum, + new_cached_conv_wm_num_frames, ) = mod.streaming_forward( src, cached_key=cached_key, @@ -857,6 +898,10 @@ def streaming_forward( cached_conv=cached_conv, cached_norm_stats=cached_norm_stats, cached_norm_len=cached_norm_len, + cached_attn_wm_sum=cached_attn_wm_sum, + cached_attn_wm_num_frames=cached_attn_wm_num_frames, + cached_conv_wm_sum=cached_conv_wm_sum, + cached_conv_wm_num_frames=cached_conv_wm_num_frames, left_context_len=left_context_len, src_key_padding_mask=src_key_padding_mask, ) @@ -871,6 +916,10 @@ def streaming_forward( new_cached_conv, new_cached_norm_stats, new_cached_norm_len, + new_cached_attn_wm_sum, + new_cached_attn_wm_num_frames, + new_cached_conv_wm_sum, + new_cached_conv_wm_num_frames, ]) offset = src_with_bypass @@ -1080,6 +1129,8 @@ def streaming_forward( left_context_len: int, cached_key: Tensor, cached_value: Tensor, + cached_wm_sum: Tensor, + cached_wm_num_frames: Tensor, key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor]: r""" @@ -1091,6 +1142,8 @@ def streaming_forward( left_context_len: length of the cached left context. cached_key: cached attention key tensor, of shape (left_context_len, batch_size, key_dim). cached_value: cached attention value tensor, of shape (left_context_len, batch_size, value_dim). + cached_wm_sum: (1, batch, channels), cumulative sum for weighted_mean + cached_wm_num_frames: (batch,), number of frames seen so far key_padding_mask: a bool tensor of shape (batch_size, left_context_len + seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting. @@ -1098,6 +1151,8 @@ def streaming_forward( - attention output, of shape (seq_len, batch_size, embed_dim) - updated cached_key, of shape (left_context_len, batch_size, key_dim) - updated cached_value, of shape (left_context_len, batch_size, value_dim) + - Updated cached_wm_sum (1, batch, channels) + - Updated cached_wm_num_frames (batch,) """ query_head_dim = self.query_head_dim num_heads = self.num_heads @@ -1141,7 +1196,16 @@ def streaming_forward( attn_weights = attn_scores.softmax(dim=-1) - v, g = self.vg_in_proj(x_vg).chunk(2, dim=-1) + vg = self.vg_in_proj(x_vg) + N = vg.shape[-1] // 3 + v = vg[..., :N] + g = vg[..., N:] + g_in, g_out = g.chunk(2, dim=-1) + v = v * self.sigmoid_in(g_in) + + wm, cached_wm_sum, cached_wm_num_frames = self.weighted_mean.streaming_forward( + v, cached_wm_sum, cached_wm_num_frames + ) # append the cached value to the current value, and update the cache assert cached_value.shape[0] == left_context_len, (cached_value.shape, left_context_len) @@ -1163,10 +1227,11 @@ def streaming_forward( ) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - v = v * self.sigmoid(g) + v = v + wm + v = v * self.sigmoid_out(g_out) v = self.out_proj(v) - return v, cached_key, cached_value + return v, cached_key, cached_value, cached_wm_sum, cached_wm_num_frames def _print_attn_entropy(self, attn_weights: Tensor): # attn_weights: (num_heads, batch_size, seq_len, seq_len) @@ -1684,7 +1749,7 @@ def forward(self, if self.causal: num_frames = torch.arange(1, T + 1, device=x.device) x_cumsum = torch.cumsum(x, dim=0) - return x_cumsum * num_frames[:, None, None] * self.weights + return x_cumsum / num_frames[:, None, None] * self.weights # assume x already masked, if mask is in use. @@ -1700,6 +1765,42 @@ def forward(self, return x.mean(dim=0) * (T / num_frames) * self.weights else: return x.mean(dim=0) * self.weights + + def streaming_forward( + self, + x: Tensor, + cached_sum: Tensor, + cached_num_frames: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Streaming forward for causal weighted mean. + + Args: + x: (time, batch, channel), the current chunk + cached_sum: (1, batch, channel), cumulative sum from previous chunks + cached_num_frames: (batch,), number of frames seen so far + + Returns: + - output: (time, batch, channel) + - new_cached_sum: (1, batch, channel) + - new_cached_num_frames: (batch,) + """ + T = x.shape[0] + # cumsum within this chunk, then add the historical sum + x_cumsum = torch.cumsum(x, dim=0) + cached_sum # (T, batch, channel) + + # num_frames for each position in this chunk: (T, batch) + num_frames = cached_num_frames.unsqueeze(0) + torch.arange( + 1, T + 1, device=x.device + ).unsqueeze(1) # (T, batch) + + output = x_cumsum / num_frames.unsqueeze(-1) * self.weights + + new_cached_sum = x_cumsum[-1:, :, :] # (1, batch, channel) + new_cached_num_frames = cached_num_frames + T # (batch,) + + return output, new_cached_sum, new_cached_num_frames + class BasisConv(nn.Module): def __init__(self, @@ -1884,25 +1985,28 @@ def forward( def streaming_forward( self, x: Tensor, - cache: Tensor, + cached_conv: Tensor, + cached_wm_sum: Tensor, + cached_wm_num_frames: Tensor, src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Compute convolution module in streaming mode. Args: x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv, of shape + cached_conv: cached left context for depthwise_conv, of shape (#batch, channels, left_pad) + cached_wm_sum: (1, batch, channels), cumulative sum for weighted_mean + cached_wm_num_frames: (batch,), number of frames seen so far src_key_padding_mask: the mask for the src keys per batch (optional): (batch, #time), contains True in masked positions. Returns: - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) + - Updated cached_conv (#batch, channels, left_pad) + - Updated cached_wm_sum (1, batch, channels) + - Updated cached_wm_num_frames (batch,) """ - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) - x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) x, s, y = x.chunk(3, dim=2) @@ -1910,25 +2014,30 @@ def streaming_forward( y = self.sigmoid2(y) x = x * s - x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) + + wm, cached_wm_sum, cached_wm_num_frames = self.weighted_mean.streaming_forward( + x, cached_wm_sum, cached_wm_num_frames + ) + + x = x.permute(1, 2, 0) # (batch, bottleneck_dim, time) x_shape = x.shape - assert cache.shape[-1] == self.left_pad, (cache.shape[-1], self.left_pad) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[..., -self.left_pad:] + assert cached_conv.shape[-1] == self.left_pad, (cached_conv.shape[-1], self.left_pad) + x = torch.cat([cached_conv, x], dim=2) + cached_conv = x[..., -self.left_pad:] x = self.depthwise_conv(x) assert x.shape == x_shape, (x.shape, x_shape) x = x.permute(2, 0, 1) # (time, batch, bottleneck_dim) + x = x + wm x = x * y x = self.out_proj(x) # (time, batch, channels) - return x, cache + return x, cached_conv, cached_wm_sum, cached_wm_num_frames def _test_zapformer_main(causal: bool = False): @@ -2100,7 +2209,7 @@ def rms(a): logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) - _test_basis_conv() - _test_zapformer_main(False) - _test_zapformer_main(True) + # _test_basis_conv() + # _test_zapformer_main(False) + # _test_zapformer_main(True) _test_zapformer_streaming() From 5583735b7c47471e9095ef102c5b835c9256f8fe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Apr 2026 12:39:04 +0800 Subject: [PATCH 1026/1191] Make LR schedule linearly decreasing rather than InterpCosine --- egs/librispeech/ASR/zapformer/train.py | 7 +++--- .../zapformer/variable_combined_scheduler.py | 22 ++++++++++++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 5228c25c9c..7fb91ec488 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -85,7 +85,7 @@ from variable_combined_scheduler import VariableCombinedLRScheduler try: - from variable_combined_scheduler import InterpCosineLRScheduler + from variable_combined_scheduler import LinearLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR @@ -1395,8 +1395,9 @@ def lr_lambda(current_step): def get_num_copies(epoch): # num_epochs arg is one-based. return max(1, int(params.max_copies * epoch / params.num_epochs)) - scheduler = InterpCosineLRScheduler(optimizer, - batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)]) + # this LinearLRScheduler inherits from VariableCombinedLRScheduler. + scheduler = LinearLRScheduler(optimizer, + batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)]) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") diff --git a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py index d825ab7d1c..b9affde921 100644 --- a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py @@ -147,7 +147,7 @@ def __init__(self, """ This cosine LR scheduler is halfway between the conventional cosine LR scheduler that takes the cosine from 0 to pi, and one that takes the cosine from 0 to pi/2. - It inherits from CombinedLRScheduler (see its documentation + It inherits from VariableCombinedLRScheduler (see its documentation to understand general aspects of usage). """ self.min_factor = min_factor @@ -161,3 +161,23 @@ def get_lr(self): factor = 0.5 * (factor + factor ** 2) factor = self.min_factor + factor * (1. - self.min_factor) return [x * factor for x in self.base_lrs] + + +class LinearLRScheduler(VariableCombinedLRScheduler): + def __init__(self, + *args, + min_factor: float = 0.05, + **kwargs): + """ + This LR scheduler decreases linearly from 1 to min_factor. + It inherits from VariableCombinedLRScheduler (see its documentation + to understand general aspects of usage). + """ + self.min_factor = min_factor + super().__init__(*args, **kwargs) + + def get_lr(self): + progress = self.get_progress() + factor = 1.0 - progress # linearly decreasing + factor = self.min_factor + factor * (1. - self.min_factor) # apply min_factor via interpolation + return [x * factor for x in self.base_lrs] From 526ce4a1daf80f2a2a8f026ab0d9902425a5bba1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Apr 2026 14:21:44 +0800 Subject: [PATCH 1027/1191] Introduce min-copies; use no-speed-perturb copy of libr train data. --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 2 +- egs/librispeech/ASR/zapformer/train.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 315d88faee..7a21574d85 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -450,7 +450,7 @@ def train_all_shuf_cuts(self) -> CutSet: train-clean-360 and train-other-500 cuts" ) return load_manifest_lazy( - self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + self.manifest_dir / "librispeech_cuts_train-all-shuf-nosp.jsonl.gz" ) def dev_clean_2_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7fb91ec488..5058daa88e 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -390,7 +390,14 @@ def get_parser(): "--max-copies", type=int, default=8, - help="The num_copies to use in the dataloader on the last epoch (it rises linearly)" + help="The num_copies to use in the dataloader on the last epoch (it rises linearly from --min-copies)" + ) + + parser.add_argument( + "--min-copies", + type=int, + default=1, + help="The num_copies to use in the dataloader on the first epoch (it rises linearly to --max-copies)" ) parser.add_argument( @@ -1394,7 +1401,7 @@ def lr_lambda(current_step): def get_num_copies(epoch): # num_epochs arg is one-based. - return max(1, int(params.max_copies * epoch / params.num_epochs)) + return params.min_copies + int((params.max_copies - params.min_copies) * epoch / params.num_epochs) # this LinearLRScheduler inherits from VariableCombinedLRScheduler. scheduler = LinearLRScheduler(optimizer, batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)]) From 44401bed59de961fbc5d4e5064dd7aedbd063201 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Apr 2026 15:00:53 +0800 Subject: [PATCH 1028/1191] Revert optimizer schedule to status in 2285. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 5058daa88e..410ae3e7e8 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -85,7 +85,7 @@ from variable_combined_scheduler import VariableCombinedLRScheduler try: - from variable_combined_scheduler import LinearLRScheduler + from variable_combined_scheduler import LinearLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR From aaae98ebd89fd58e84e1dcc43d79844daeb799f3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Apr 2026 15:39:58 +0800 Subject: [PATCH 1029/1191] Cosmetic fix. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 410ae3e7e8..5058daa88e 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -85,7 +85,7 @@ from variable_combined_scheduler import VariableCombinedLRScheduler try: - from variable_combined_scheduler import LinearLRScheduler + from variable_combined_scheduler import LinearLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR From 7a3daf6011310db38c57248e670599e190860e4c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Apr 2026 11:06:41 +0800 Subject: [PATCH 1030/1191] Make num-copies rise exponentially with epoch rather than linearly. --- .../ASR/zapformer/asr_datamodule.py | 6 ++--- egs/librispeech/ASR/zapformer/train.py | 26 ++++++++++++------- icefall/utils.py | 6 +++++ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 7a21574d85..eff2c637ca 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -50,7 +50,7 @@ from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader -from icefall.utils import str2bool +from icefall.utils import str2bool, dist_barrier class _SeedWorkers: @@ -329,7 +329,7 @@ def train_dataloaders( # need torch.distributed.barrier() before and after anything that might call lhotse.fix_random_seed() as it fixes random seeds of all GPUs, # not just the GPU of this process. - torch.distributed.barrier() + dist_barrier() train_dl = DataLoader( train, sampler=train_sampler, @@ -338,7 +338,7 @@ def train_dataloaders( persistent_workers=False, worker_init_fn=worker_init_fn, ) - torch.distributed.barrier() + dist_barrier() return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 5058daa88e..9a17f11e42 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -115,6 +115,7 @@ setup_logger, str2bool, time_warp, + dist_barrier, ) @@ -389,15 +390,15 @@ def get_parser(): parser.add_argument( "--max-copies", type=int, - default=8, - help="The num_copies to use in the dataloader on the last epoch (it rises linearly from --min-copies)" + default=16, + help="The num_copies to use in the dataloader on the last epoch (it rises geometrically from --min-copies)" ) parser.add_argument( "--min-copies", type=int, default=1, - help="The num_copies to use in the dataloader on the first epoch (it rises linearly to --max-copies)" + help="The num_copies to use in the dataloader on the first epoch (it rises geometrically to --max-copies)" ) parser.add_argument( @@ -1310,7 +1311,7 @@ def run(rank, world_size, args): setup_dist(rank, world_size, params.master_port) # need torch.distributed.barrier() after fix_random_seed() as it fixes # random seeds of all GPUs, not just the GPU of this process. - torch.distributed.barrier() + dist_barrier() setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") @@ -1401,10 +1402,15 @@ def lr_lambda(current_step): def get_num_copies(epoch): # num_epochs arg is one-based. - return params.min_copies + int((params.max_copies - params.min_copies) * epoch / params.num_epochs) + # "progress" is progress between 0 and 1. subtract 0.99999 rather than 1 to avoid nan if --epochs=1. + progress = (epoch - 0.99999) / (params.num_epochs - 0.99999) + return int(params.min_copies * (params.max_copies / params.min_copies) ** progress) + + batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)] + logging.info(f"Tot real epochs = {sum(batches_per_epoch) / params.batches_per_epoch}") # this LinearLRScheduler inherits from VariableCombinedLRScheduler. scheduler = LinearLRScheduler(optimizer, - batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)]) + batches_per_epoch=batches_per_epoch) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1547,9 +1553,9 @@ def remove_short_and_long_utt(c: Cut): for epoch in range(params.start_epoch, params.num_epochs + 1): # fix the random seed before - torch.distributed.barrier() + dist_barrier() fix_random_seed(params.seed + epoch - 1) - torch.distributed.barrier() + dist_barrier() num_copies = get_num_copies(epoch) logging.info(f"On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") @@ -1562,9 +1568,9 @@ def remove_short_and_long_utt(c: Cut): train_dl.sampler.set_epoch(epoch - 1) # Re-do fixing the random seed because I believe in asr_datamodule.train_dataloaders(), fix_random_seed() # may get called from an arbitrary worker and affect the seed of *all* the GPUs. - torch.distributed.barrier() + dist_barrier() fix_random_seed(params.seed + epoch - 1) - torch.distributed.barrier() + dist_barrier() if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) diff --git a/icefall/utils.py b/icefall/utils.py index ca682a0326..c5404a20cb 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -159,6 +159,12 @@ def str2bool(v): raise argparse.ArgumentTypeError("Boolean value expected.") +def dist_barrier() -> None: + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + if world_size > 1: + dist.barrier() + def setup_logger( log_filename: Pathlike, log_level: str = "info", From bc2e4e88057fb1bc7c01945779e169b36bda4a55 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Apr 2026 12:30:18 +0800 Subject: [PATCH 1031/1191] Make LR depend on epoch not batch, as in 2284, and use linear decay with minimum of 0.05. --- .../ASR/zapformer/combined_scheduler.py | 14 ++++++------- egs/librispeech/ASR/zapformer/train.py | 21 +++++++++++-------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 5962a4ae0a..793e5fe8f9 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -182,7 +182,7 @@ class InterpCosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, min_factor: float = 0.05, - **kwargs): + **kwargs): # takes also batches_per_epoch and num_epochs args. """ This cosine LR scheduler is halfway between the conventional cosine LR scheduler that takes the cosine from 0 to pi, and one that takes the cosine from 0 to pi/2. @@ -205,18 +205,16 @@ def get_lr(self): class LinearLRScheduler(CombinedLRScheduler): def __init__(self, *args, - const_fraction: float = 0.2, # fraction of schedule for which we stay at 1.0 - min_factor: float = 0.1, - **kwargs): + min_factor: float = 0.0, + **kwargs): # takes also batches_per_epoch and num_epochs args. super().__init__(*args, **kwargs) - self.const_fraction = const_fraction self.min_factor = min_factor def get_lr(self): progress = self.get_progress() # initially: factor is constant at 1.0 until progress==self.const_fraction, then decays to 0 # at the end. - factor = (1.0 if progress <= self.const_fraction else (1.0 - progress) / (1. - self.const_fraction)) - # then, modify for self.min_factor - factor = max(factor, self.min_factor) + factor = 1.0 - progress + min_factor = self.min_factor + factor = min_factor + (1.0 - self.min_factor) * factor return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 9a17f11e42..dfa496170a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -83,9 +83,9 @@ pass -from variable_combined_scheduler import VariableCombinedLRScheduler +from combined_scheduler import CombinedLRScheduler try: - from variable_combined_scheduler import LinearLRScheduler + from combined_scheduler import LinearLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR @@ -796,7 +796,7 @@ def load_checkpoint_if_available( model: nn.Module, model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[VariableCombinedLRScheduler] = None, + scheduler: Optional[CombinedLRScheduler] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -862,7 +862,7 @@ def save_checkpoint( model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[VariableCombinedLRScheduler] = None, + scheduler: Optional[CombinedLRScheduler] = None, sampler: Optional[CutSampler] = None, scaler: Optional[GradScaler] = None, rank: int = 0, @@ -1078,7 +1078,7 @@ def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, - scheduler: VariableCombinedLRScheduler, + scheduler: CombinedLRScheduler, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -1406,11 +1406,14 @@ def get_num_copies(epoch): progress = (epoch - 0.99999) / (params.num_epochs - 0.99999) return int(params.min_copies * (params.max_copies / params.min_copies) ** progress) - batches_per_epoch=[params.batches_per_epoch * get_num_copies(i) for i in range(1, params.num_epochs+1)] - logging.info(f"Tot real epochs = {sum(batches_per_epoch) / params.batches_per_epoch}") - # this LinearLRScheduler inherits from VariableCombinedLRScheduler. + logging.info(f"Tot real epochs = {sum(get_num_copies(i) for i in range(1, params.num_epochs+1))}") + + # this LinearLRScheduler inherits from CombinedLRScheduler. progress decays + # in a way that's linear (actually, affine) with epoch rather than progress in batches. scheduler = LinearLRScheduler(optimizer, - batches_per_epoch=batches_per_epoch) + min_factor=0.05, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 061afac0fa645c421695c5d9c13a39e86446e2bf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Apr 2026 12:58:47 +0800 Subject: [PATCH 1032/1191] Use try-except for importing dist_barrier. --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 8 +++++++- egs/librispeech/ASR/zapformer/train.py | 5 ++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index eff2c637ca..277949701c 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -43,6 +43,12 @@ except: pass +try: + from icefall.utils import dist_barrier +except: + pass + + from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, OnTheFlyFeatures, @@ -50,7 +56,7 @@ from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader -from icefall.utils import str2bool, dist_barrier +from icefall.utils import str2bool class _SeedWorkers: diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index dfa496170a..0143d42f24 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -115,8 +115,11 @@ setup_logger, str2bool, time_warp, - dist_barrier, ) +try: + from icefall.utils import dist_barrier +except: + pass def get_adjusted_batch_count(params: AttributeDict) -> float: From 93c0efedaed7573f44b9c29c849dda8774ab3d21 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Apr 2026 13:07:53 +0800 Subject: [PATCH 1033/1191] Reduce min_factor of LinearLRScheduler from 0.05 to 0.025 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 0143d42f24..785ea17ace 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1414,7 +1414,7 @@ def get_num_copies(epoch): # this LinearLRScheduler inherits from CombinedLRScheduler. progress decays # in a way that's linear (actually, affine) with epoch rather than progress in batches. scheduler = LinearLRScheduler(optimizer, - min_factor=0.05, + min_factor=0.025, batches_per_epoch=params.batches_per_epoch, num_epochs=params.num_epochs) From da056e3e1c8dd74968720cf68499151aa5e99f24 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Apr 2026 22:09:31 +0800 Subject: [PATCH 1034/1191] Replace LinearLRScheduler with HalfCosineLRScheduler --- .../ASR/zapformer/combined_scheduler.py | 22 +++++++++++++++++++ egs/librispeech/ASR/zapformer/train.py | 12 +++++----- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 793e5fe8f9..6a26758897 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -202,6 +202,28 @@ def get_lr(self): return [x * factor for x in self.base_lrs] +class HalfCosineLRScheduler(CombinedLRScheduler): + def __init__(self, + *args, + min_factor: float = 0.05, + **kwargs): # takes also batches_per_epoch and num_epochs args. + """ + This cosine LR scheduler is the cosine from 0 to pi/2, with no offset of 1. + It inherits from CombinedLRScheduler (see its documentation + to understand general aspects of usage). + """ + self.min_factor = min_factor + super().__init__(*args, **kwargs) + + def get_lr(self): + progress = self.get_progress() + factor = math.cos((math.pi / 2) * progress) + # factor**2 would be the conventional cosine LR scheduler with cosine from 0 to pi, we interpolate + # between the two. + factor = self.min_factor + factor * (1. - self.min_factor) + return [x * factor for x in self.base_lrs] + + class LinearLRScheduler(CombinedLRScheduler): def __init__(self, *args, diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 785ea17ace..dd818ea628 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -85,7 +85,7 @@ from combined_scheduler import CombinedLRScheduler try: - from combined_scheduler import LinearLRScheduler + from combined_scheduler import HalfCosineLRScheduler except: pass from torch.optim.lr_scheduler import LambdaLR @@ -1411,12 +1411,12 @@ def get_num_copies(epoch): logging.info(f"Tot real epochs = {sum(get_num_copies(i) for i in range(1, params.num_epochs+1))}") - # this LinearLRScheduler inherits from CombinedLRScheduler. progress decays + # this HalfCosineLRScheduler inherits from CombinedLRScheduler. progress decays # in a way that's linear (actually, affine) with epoch rather than progress in batches. - scheduler = LinearLRScheduler(optimizer, - min_factor=0.025, - batches_per_epoch=params.batches_per_epoch, - num_epochs=params.num_epochs) + scheduler = HalfCosineLRScheduler(optimizer, + min_factor=0.025, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 538d5890305d95e242bb4f08897e99f8b521f142 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Apr 2026 22:54:57 +0800 Subject: [PATCH 1035/1191] Use VariableCombinedLRScheduler of linear-decay type; make num-copies rise linearly with step, not epoch; have user provide --num-real-epochs not --num-epochs. --- egs/librispeech/ASR/zapformer/train.py | 61 ++++++++++++++------------ 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index dd818ea628..0db0da6baf 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -83,11 +83,15 @@ pass -from combined_scheduler import CombinedLRScheduler +from variable_combined_scheduler import VariableCombinedLRScheduler try: - from combined_scheduler import HalfCosineLRScheduler + from variable_combined_scheduler import LinearLRScheduler + LRSchedulerType = VariableCombinedLRSchedule except: pass + +SchedulerType = "VariableCombinedLRScheduler" + from torch.optim.lr_scheduler import LambdaLR from subsampling import Conv2dSubsampling from torch import Tensor @@ -384,24 +388,24 @@ def get_parser(): ) parser.add_argument( - "--num-epochs", + "--num-real-epochs", type=int, default=30, - help="Number of epochs to train.", + help="Number of epochs to train, including number of copies; num-epochs will be <= this.", ) parser.add_argument( "--max-copies", type=int, default=16, - help="The num_copies to use in the dataloader on the last epoch (it rises geometrically from --min-copies)" + help="The num_copies to use in the dataloader on the last epoch (it rises linearly with step count from --min-copies)" ) parser.add_argument( "--min-copies", type=int, default=1, - help="The num_copies to use in the dataloader on the first epoch (it rises geometrically to --max-copies)" + help="The num_copies to use in the dataloader on the first epoch (it rises linearly with step count to --max-copies)" ) parser.add_argument( @@ -799,7 +803,7 @@ def load_checkpoint_if_available( model: nn.Module, model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[CombinedLRScheduler] = None, + scheduler: Optional[SchedulerType] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -865,7 +869,7 @@ def save_checkpoint( model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[CombinedLRScheduler] = None, + scheduler: Optional[SchedulerType] = None, sampler: Optional[CutSampler] = None, scaler: Optional[GradScaler] = None, rank: int = 0, @@ -1081,7 +1085,7 @@ def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, - scheduler: CombinedLRScheduler, + scheduler: SchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -1395,28 +1399,27 @@ def run(rank, world_size, args): beta1=0.995, ) - # hardcode batches per epoch for now. - total_steps = 4550 * params.num_epochs - def lr_lambda(current_step): - # Cosine annealing - progress = current_step / total_steps - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) - - def get_num_copies(epoch): - # num_epochs arg is one-based. - # "progress" is progress between 0 and 1. subtract 0.99999 rather than 1 to avoid nan if --epochs=1. - progress = (epoch - 0.99999) / (params.num_epochs - 0.99999) - return int(params.min_copies * (params.max_copies / params.min_copies) ** progress) + if True: + # Work out copies_per_epoch + copies_per_epoch = [ ] + cur_real_epochs = 0 + for n in range(params.min_copies, params.max_copies + 1): + progress = (n + 1 - params.min_copies) / (params.max_copies + 1 - params.min_copies) + target_real_epochs = int(0.5 + progress * params.num_real_epochs) # + 0.5 to round up. + while cur_real_epochs < target_real_epochs: + copies_per_epoch.append(n) + cur_real_epochs += n - logging.info(f"Tot real epochs = {sum(get_num_copies(i) for i in range(1, params.num_epochs+1))}") + num_epochs = len(copies_per_epoch) + logging.info(f"Num epochs = {len(copies_per_epoch)}, num-real-epochs={sum(copies_per_epoch)} vs target {params.num_real_epochs}") + logging.info(f"Copies per epoch: {copies_per_epoch}") - # this HalfCosineLRScheduler inherits from CombinedLRScheduler. progress decays + # this LinearLRScheduler inherits from VariableCombinedLRScheduler. progress decays # in a way that's linear (actually, affine) with epoch rather than progress in batches. - scheduler = HalfCosineLRScheduler(optimizer, - min_factor=0.025, - batches_per_epoch=params.batches_per_epoch, - num_epochs=params.num_epochs) + scheduler = LinearLRScheduler(optimizer, + min_factor=0.025, + batches_per_epoch=[params.batches_per_epoch * n for n in copies_per_epoch]) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1557,13 +1560,13 @@ def remove_short_and_long_utt(c: Cut): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - for epoch in range(params.start_epoch, params.num_epochs + 1): + for epoch in range(params.start_epoch, num_epochs + 1): # fix the random seed before dist_barrier() fix_random_seed(params.seed + epoch - 1) dist_barrier() - num_copies = get_num_copies(epoch) + num_copies = copies_per_epoch[epoch - 1] logging.info(f"On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") train_dl = asr_datamodule.train_dataloaders( train_cuts, From 140f27542963305f61577dd8c9dc473812b49dad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Apr 2026 23:20:09 +0800 Subject: [PATCH 1036/1191] Change LinearLRScheduler to InterpCosineLRScheduler, still with min_factor=0.025 --- egs/librispeech/ASR/zapformer/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 0db0da6baf..88205745a7 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -85,7 +85,7 @@ from variable_combined_scheduler import VariableCombinedLRScheduler try: - from variable_combined_scheduler import LinearLRScheduler + from variable_combined_scheduler import InterpCosineLRScheduler LRSchedulerType = VariableCombinedLRSchedule except: pass @@ -1415,11 +1415,11 @@ def run(rank, world_size, args): logging.info(f"Num epochs = {len(copies_per_epoch)}, num-real-epochs={sum(copies_per_epoch)} vs target {params.num_real_epochs}") logging.info(f"Copies per epoch: {copies_per_epoch}") - # this LinearLRScheduler inherits from VariableCombinedLRScheduler. progress decays + # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. progress decays # in a way that's linear (actually, affine) with epoch rather than progress in batches. - scheduler = LinearLRScheduler(optimizer, - min_factor=0.025, - batches_per_epoch=[params.batches_per_epoch * n for n in copies_per_epoch]) + scheduler = InterpCosineLRScheduler(optimizer, + min_factor=0.025, + batches_per_epoch=[params.batches_per_epoch * n for n in copies_per_epoch]) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") From 5c817c1167c161d1c6a27283c24ddfd739877793 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Apr 2026 00:16:43 +0800 Subject: [PATCH 1037/1191] squared_scale=0.75 in InterpCosineLRScheduler, make its final linear phase have twice smaller linear slope. --- egs/librispeech/ASR/zapformer/train.py | 3 +++ egs/librispeech/ASR/zapformer/variable_combined_scheduler.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 88205745a7..1462a4a2d1 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1417,8 +1417,11 @@ def run(rank, world_size, args): # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. progress decays # in a way that's linear (actually, affine) with epoch rather than progress in batches. + # squared_scale=0.75 takes us a bit closer to the traditional cosine LR scheduler that + # starts and ends constant. scheduler = InterpCosineLRScheduler(optimizer, min_factor=0.025, + squared_scale=0.75, batches_per_epoch=[params.batches_per_epoch * n for n in copies_per_epoch]) if checkpoints and "optimizer" in checkpoints: diff --git a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py index b9affde921..3ae6fb4a5f 100644 --- a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py @@ -138,11 +138,11 @@ def print_lr(self, is_verbose, group, lr): - class InterpCosineLRScheduler(VariableCombinedLRScheduler): def __init__(self, *args, min_factor: float = 0.05, + squared_scale: float = 0.5, **kwargs): """ This cosine LR scheduler is halfway between the conventional cosine LR scheduler @@ -151,6 +151,7 @@ def __init__(self, to understand general aspects of usage). """ self.min_factor = min_factor + self.squared_scale = squared_scale super().__init__(*args, **kwargs) def get_lr(self): @@ -158,7 +159,7 @@ def get_lr(self): factor = math.cos((math.pi / 2) * progress) # factor**2 would be the conventional cosine LR scheduler with cosine from 0 to pi, we interpolate # between the two. - factor = 0.5 * (factor + factor ** 2) + factor = (1. - self.squared_scale) * factor + self.squared_scale * factor ** 2 factor = self.min_factor + factor * (1. - self.min_factor) return [x * factor for x in self.base_lrs] From 93c6405d67c643ae96d1a54560fca848eef44442 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Apr 2026 16:04:38 +0800 Subject: [PATCH 1038/1191] Make InterpCosineLRScheduler more general to include linear function; half half-linear, half-conventional-cosine --- egs/librispeech/ASR/zapformer/train.py | 9 +++--- .../zapformer/variable_combined_scheduler.py | 31 ++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 1462a4a2d1..3d2b8ee8d7 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1415,13 +1415,12 @@ def run(rank, world_size, args): logging.info(f"Num epochs = {len(copies_per_epoch)}, num-real-epochs={sum(copies_per_epoch)} vs target {params.num_real_epochs}") logging.info(f"Copies per epoch: {copies_per_epoch}") - # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. progress decays - # in a way that's linear (actually, affine) with epoch rather than progress in batches. - # squared_scale=0.75 takes us a bit closer to the traditional cosine LR scheduler that - # starts and ends constant. + # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. + # this configuration is halfway between a linear function (1 to 0) and the conventional + # cosine LR scheduler. It decays to a minimum of 0.025. scheduler = InterpCosineLRScheduler(optimizer, min_factor=0.025, - squared_scale=0.75, + linear_scale=0.5, batches_per_epoch=[params.batches_per_epoch * n for n in copies_per_epoch]) if checkpoints and "optimizer" in checkpoints: diff --git a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py index 3ae6fb4a5f..b3a9fb262f 100644 --- a/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/variable_combined_scheduler.py @@ -141,25 +141,34 @@ def print_lr(self, is_verbose, group, lr): class InterpCosineLRScheduler(VariableCombinedLRScheduler): def __init__(self, *args, - min_factor: float = 0.05, - squared_scale: float = 0.5, + min_factor: float = 0.0, + half_cosine_scale: float = 0.0, + linear_scale: float = 0.0, **kwargs): """ - This cosine LR scheduler is halfway between the conventional cosine LR scheduler - that takes the cosine from 0 to pi, and one that takes the cosine from 0 to pi/2. - It inherits from VariableCombinedLRScheduler (see its documentation - to understand general aspects of usage). + This cosine LR scheduler encompasses the conventional cosine LR scheduler + that takes the cosine from 0 to pi (shifted to 0..1), the half-cosine LR + scheduler that takes the cosine from 0 to pi, and the linear LR scheduler + that takes the linear function from 1 to 0. """ self.min_factor = min_factor - self.squared_scale = squared_scale + self.half_cosine_scale = half_cosine_scale + self.linear_scale = linear_scale super().__init__(*args, **kwargs) def get_lr(self): progress = self.get_progress() - factor = math.cos((math.pi / 2) * progress) - # factor**2 would be the conventional cosine LR scheduler with cosine from 0 to pi, we interpolate - # between the two. - factor = (1. - self.squared_scale) * factor + self.squared_scale * factor ** 2 + half_cos = math.cos((math.pi / 2) * progress) + cos = half_cos ** 2 + linear = 1. - progress + + linear_scale = self.linear_scale + half_cosine_scale = self.half_cosine_scale + cosine_scale = 1. - self.half_cosine_scale - linear_scale + assert cosine_scale >= 0.0 + + factor = linear_scale * linear + half_cosine_scale * half_cos + cosine_scale * cos + # apply min_factor via interpolation factor = self.min_factor + factor * (1. - self.min_factor) return [x * factor for x in self.base_lrs] From 182752b09d8d362467c4470d2bbd36123e7a5901 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Apr 2026 16:06:01 +0800 Subject: [PATCH 1039/1191] Use the speed-perturb, not the nosp, version of the librispeech data. --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 277949701c..439da5e4d4 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -456,7 +456,7 @@ def train_all_shuf_cuts(self) -> CutSet: train-clean-360 and train-other-500 cuts" ) return load_manifest_lazy( - self.manifest_dir / "librispeech_cuts_train-all-shuf-nosp.jsonl.gz" + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" ) def dev_clean_2_cuts(self) -> CutSet: From 566d591eaeab6e275d7e7d584969953be46be128 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Apr 2026 16:31:19 +0800 Subject: [PATCH 1040/1191] Change code to get copies_per_epoch to be reversed, to minimize rounding errors --- egs/librispeech/ASR/zapformer/train.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 3d2b8ee8d7..742f9ee3ab 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1404,12 +1404,16 @@ def run(rank, world_size, args): # Work out copies_per_epoch copies_per_epoch = [ ] cur_real_epochs = 0 - for n in range(params.min_copies, params.max_copies + 1): - progress = (n + 1 - params.min_copies) / (params.max_copies + 1 - params.min_copies) - target_real_epochs = int(0.5 + progress * params.num_real_epochs) # + 0.5 to round up. + progress_increment = 1.0 / (params.max_copies + 1 - params.min_copies) + cur_progress = 0.0 + # go in backwards order to minimize rounding errors. + for n in reversed(range(params.min_copies, params.max_copies + 1)): + cur_progress += progress_increment + target_real_epochs = int(0.5 + cur_progress * params.num_real_epochs) # + 0.5 to round up. while cur_real_epochs < target_real_epochs: copies_per_epoch.append(n) cur_real_epochs += n + copies_per_epoch = list(reversed(copies_per_epoch)) num_epochs = len(copies_per_epoch) logging.info(f"Num epochs = {len(copies_per_epoch)}, num-real-epochs={sum(copies_per_epoch)} vs target {params.num_real_epochs}") From c08b95b6077eaf874457afd62631da0f5c4f62fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 11:13:40 +0800 Subject: [PATCH 1041/1191] Add debugging print statements to show model param values and random values in augmentation. --- .../ASR/zapformer/alternating_spec_augment.py | 3 +++ egs/librispeech/ASR/zapformer/train.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index b05da7b005..7b37c502aa 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -1,3 +1,4 @@ +import logging import random from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union @@ -168,6 +169,8 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ # "rlength" means relative length of each mask, i.e. relative to seq_len. the # lengths in mask_lengths are normalized lengths. mask_rlengths = torch.rand(B, M, device=device) * (max_mask_fraction / num_masks) + if (seq_len + batch_size) % 100 == 0: # pseudo-randomly print the random numbers. i want to test repeatability. + logging.info(f"mask_rlengths: {mask_rlengths.flatten()}") mask_tot_rlen = mask_rlengths.sum(dim=1, keepdim=True) # (batch_size, 1) # padding_tot_rlen is the total relative length of the padding segmnts. We clamp to min=0.25 diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 88205745a7..6890b63036 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1080,6 +1080,14 @@ def compute_validation_loss( return tot_loss +def show_model_params(model: nn.Module): + with torch.no_grad(): + params = [ ] + for p in model.parameters(): + params.append(p.flatten()[-1:]) + all_last_elems = torch.cat(params) + logging.info(f"All last elems of parameters = {all_last_elems}, sum = {all_last_elems.sum()}") + def train_one_epoch( params: AttributeDict, @@ -1155,6 +1163,9 @@ def save_bad_model(suffix: str = ""): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx % 200 == 0: + show_model_params(model) + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) From 91ac7128898ed5c2c4aa7a172a05b19f21aa7200 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 11:42:23 +0800 Subject: [PATCH 1042/1191] Change debug statements for balance --- egs/librispeech/ASR/zapformer/alternating_spec_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 7b37c502aa..77bab3dadf 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -169,8 +169,8 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ # "rlength" means relative length of each mask, i.e. relative to seq_len. the # lengths in mask_lengths are normalized lengths. mask_rlengths = torch.rand(B, M, device=device) * (max_mask_fraction / num_masks) - if (seq_len + batch_size) % 100 == 0: # pseudo-randomly print the random numbers. i want to test repeatability. - logging.info(f"mask_rlengths: {mask_rlengths.flatten()}") + if (seq_len + batch_size) % 10 == 0: # pseudo-randomly print the random numbers. i want to test repeatability. + logging.info(f"mask_rlengths: {mask_rlengths.flatten()[:10]}") mask_tot_rlen = mask_rlengths.sum(dim=1, keepdim=True) # (batch_size, 1) # padding_tot_rlen is the total relative length of the padding segmnts. We clamp to min=0.25 From 2a8498f6a59e534c3ae89c0d4270e52803e4ed63 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 12:52:40 +0800 Subject: [PATCH 1043/1191] Update debug statement --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6890b63036..f1d5366299 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1086,7 +1086,7 @@ def show_model_params(model: nn.Module): for p in model.parameters(): params.append(p.flatten()[-1:]) all_last_elems = torch.cat(params) - logging.info(f"All last elems of parameters = {all_last_elems}, sum = {all_last_elems.sum()}") + logging.info(f"All last elems of parameters sum = {all_last_elems.sum()}") def train_one_epoch( From 99b041239ab5fdd5c99df8bd47d217a254191208 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 13:54:49 +0800 Subject: [PATCH 1044/1191] Desynchronize the torch rng's; and rely on only the torch rng in time_warp. --- egs/librispeech/ASR/zapformer/train.py | 18 ++++++++-- icefall/utils.py | 49 +++++++++++++++++++++++--- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 88205745a7..08359b9cbe 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1330,6 +1330,7 @@ def run(rank, world_size, args): device = torch.device("cpu") if torch.cuda.is_available(): + torch.cuda.set_device(rank) device = torch.device("cuda", rank) logging.info(f"Device: {device}") @@ -1561,7 +1562,8 @@ def remove_short_and_long_utt(c: Cut): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, num_epochs + 1): - # fix the random seed before + # fix all random seeds before starting the dataloaders, as they require + # all seeds to be synchronized. dist_barrier() fix_random_seed(params.seed + epoch - 1) dist_barrier() @@ -1576,13 +1578,25 @@ def remove_short_and_long_utt(c: Cut): sampler_state_dict=None train_dl.sampler.set_epoch(epoch - 1) # Re-do fixing the random seed because I believe in asr_datamodule.train_dataloaders(), fix_random_seed() - # may get called from an arbitrary worker and affect the seed of *all* the GPUs. + # may get called from an arbitrary worker with a worker-specific offset, and affect the seed of *all* the GPUs. dist_barrier() fix_random_seed(params.seed + epoch - 1) + # fix_random_seed may get called from an arbitrary worker with a + # worker-specific offset, and affect the seed of *all* the GPUs, with uncertain timing. so we + # call it again to make sure the seed is deterministic. dist_barrier() + # now desynchronize the torch RNGs for CPU and GPU by calling rand() a + # different number of times, so the augmentation isn't the same across + # ranks. It's very difficult to do this with torch.manual_seed() + # because it has no way to set the RNG for just the CPU. + for _ in range(rank): + torch.randn(100) + torch.randn(100, device=device) + if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + tb_writer.add_scalar("train/num_copies", num_copies, params.batch_idx_train) params.cur_epoch = epoch scheduler.set_epoch(epoch) diff --git a/icefall/utils.py b/icefall/utils.py index c5404a20cb..15961e89da 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -43,7 +43,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from lhotse.dataset.signal_transforms import time_warp as time_warp_impl from packaging import version from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials @@ -2425,6 +2424,45 @@ def num_tokens( return num_tokens +def time_warp_impl(features: torch.Tensor, factor: int) -> torch.Tensor: + """ + # modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py#L338C1-L369C1 + # to use torch rng rather than the numpy one, this has to do with which rngs + # are synchronized and which are not. (we keep the numpy and python rng's synchronized + # for the sake of lhotse's sampler code, where they need to be synchronized to avoid data + # overlap). + + Time warping as described in the SpecAugment paper. + Implementation based on Espresso: + https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51 + + :param features: input tensor of shape ``(T, F)`` + :param factor: time warping parameter. + :return: a warped tensor of shape ``(T, F)`` + """ + t = features.size(0) + if t - factor <= factor + 1: + return features + center = torch.randint(factor + 1, t - factor, ()).item() + warped = torch.randint(center - factor, center + factor + 1) + if warped == center: + return features + features = features.unsqueeze(0).unsqueeze(0) + left = torch.nn.functional.interpolate( + features[:, :, :center, :], + size=(warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + right = torch.nn.functional.interpolate( + features[:, :, center:, :], + size=(t - warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) + + # Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py def time_warp( features: torch.Tensor, @@ -2443,10 +2481,13 @@ def time_warp( len(features.shape) == 3 ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}" features = features.clone() + + # we use torch.rand(1).item() instead of random.random() because for lhotse reasons we keep the + # python RNG synchronized across ranks, but we keep the torch RNG desynchronized. if supervision_segments is None and feature_lens is None: # No supervisions - apply spec augment to full feature matrices. for sequence_idx in range(features.size(0)): - if random.random() > p: + if torch.rand(1).item() > p: # Randomly choose whether this transform is applied continue features[sequence_idx] = time_warp_impl( @@ -2456,7 +2497,7 @@ def time_warp( assert feature_lens is None # Supervisions provided - we will apply time warping only on the supervised areas. for sequence_idx, start_frame, num_frames in supervision_segments: - if random.random() > p: + if torch.rand(1).item() > p: # Randomly choose whether this transform is applied continue end_frame = start_frame + num_frames @@ -2466,7 +2507,7 @@ def time_warp( else: for sequence_idx, num_frames in enumerate(feature_lens): - if random.random() > p: + if torch.rand(1).item() > p: # Randomly choose whether this transform is applied continue features[sequence_idx, :num_frames] = time_warp_impl( From b7e1a9520697ead3cc002bb31449ff69c966a137 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 13:56:41 +0800 Subject: [PATCH 1045/1191] Add comment --- egs/librispeech/ASR/zapformer/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 08359b9cbe..da53d8182a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1588,9 +1588,11 @@ def remove_short_and_long_utt(c: Cut): # now desynchronize the torch RNGs for CPU and GPU by calling rand() a # different number of times, so the augmentation isn't the same across - # ranks. It's very difficult to do this with torch.manual_seed() - # because it has no way to set the RNG for just the CPU. - for _ in range(rank): + # ranks within a batch. It's very difficult to do this with + # torch.manual_seed() because it has no way to set the RNG for just the + # CPU. This is not 100% ideal as we'll still possibly repeat after a delay, but + # it's simple to implement. + for _ in range(rank * 4): torch.randn(100) torch.randn(100, device=device) From 9c4e419c1fd1ee7f9164da59f87ca12833a7277b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 14:08:50 +0800 Subject: [PATCH 1046/1191] Bug fix --- icefall/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/utils.py b/icefall/utils.py index 15961e89da..6f3bdd17e4 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -2444,7 +2444,7 @@ def time_warp_impl(features: torch.Tensor, factor: int) -> torch.Tensor: if t - factor <= factor + 1: return features center = torch.randint(factor + 1, t - factor, ()).item() - warped = torch.randint(center - factor, center + factor + 1) + warped = torch.randint(center - factor, center + factor + 1, ()).item() if warped == center: return features features = features.unsqueeze(0).unsqueeze(0) From e11db7febbf3205cf7762a0a7ea63b11d88e2449 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 14:28:21 +0800 Subject: [PATCH 1047/1191] Set cuda seed in a more complete way --- egs/librispeech/ASR/zapformer/train.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 672a38ddf9..fc7c715056 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1597,15 +1597,14 @@ def remove_short_and_long_utt(c: Cut): # call it again to make sure the seed is deterministic. dist_barrier() - # now desynchronize the torch RNGs for CPU and GPU by calling rand() a - # different number of times, so the augmentation isn't the same across - # ranks within a batch. It's very difficult to do this with - # torch.manual_seed() because it has no way to set the RNG for just the - # CPU. This is not 100% ideal as we'll still possibly repeat after a delay, but - # it's simple to implement. - for _ in range(rank * 4): - torch.randn(100) - torch.randn(100, device=device) + with torch.cuda.device(rank): + # set CUDA seed for "my GPU" in a rank-dependent way. assume the only multi-node training we'll + # do is with cuda so do not worry about CPU seed. in fact, we do also rely on the + # torch CPU random number generator for data augmentation- see time_warp()- + # but this gets naturally desynchronized quite soon because it's called in a loop + # that depends on the number of elements in a batch. + torch.cuda.manual_seed(params.seed + epoch - 1 + 1000 * rank) + if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) From 62637765e2ec15529abfb38c292d69b64005988e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 15:13:22 +0800 Subject: [PATCH 1048/1191] Remove debugging statements. --- egs/librispeech/ASR/zapformer/alternating_spec_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 77bab3dadf..927e780261 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -169,8 +169,8 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ # "rlength" means relative length of each mask, i.e. relative to seq_len. the # lengths in mask_lengths are normalized lengths. mask_rlengths = torch.rand(B, M, device=device) * (max_mask_fraction / num_masks) - if (seq_len + batch_size) % 10 == 0: # pseudo-randomly print the random numbers. i want to test repeatability. - logging.info(f"mask_rlengths: {mask_rlengths.flatten()[:10]}") + #if (seq_len + batch_size) % 10 == 0: # pseudo-randomly print the random numbers. i want to test repeatability. + # logging.info(f"mask_rlengths: {mask_rlengths.flatten()[:10]}") mask_tot_rlen = mask_rlengths.sum(dim=1, keepdim=True) # (batch_size, 1) # padding_tot_rlen is the total relative length of the padding segmnts. We clamp to min=0.25 From f981352aac2c683f69c0db7ade8307b65a18c1d4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 16:31:35 +0800 Subject: [PATCH 1049/1191] Do random seeding and initialization of sampler differently; do not make that seed the torch seed. --- .../ASR/zapformer/asr_datamodule.py | 32 ++++++++++--------- egs/librispeech/ASR/zapformer/train.py | 31 ++++++------------ 2 files changed, 26 insertions(+), 37 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 277949701c..3ee0110397 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -22,6 +22,8 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import random # to set its random seed +import numpy # to set its random seed import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy @@ -53,7 +55,6 @@ AudioSamples, OnTheFlyFeatures, ) -from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -64,11 +65,9 @@ def __init__(self, seed: int): self.seed = seed def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriSpeechAsrDataModule: - pass # only left here so other branches can run in the same directory. TODO: remove. + random_seed = self.seed + 9999 * worker_id + random.seed(random_seed) + np.random.seed(random_seed) class AsrDataModule: """ @@ -243,6 +242,11 @@ def train_dataloaders( cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, num_copies: int = 1, + seed: int = 100, # lets us specify different seed if we create data loader on different epochs. + # note: the seed has to be the same across ranks, because the samplers need to be kept in sync + # so we can divide up the data accurately. + rank: int = 0, # the torch. distributed rank, affects the seed used for + ) -> DataLoader: """ Args: @@ -314,6 +318,7 @@ def train_dataloaders( buffer_size=self.args.num_buckets * 2000, shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, + seed=seed, ) else: logging.info(f"Using SimpleCutSampler, num_copies={num_copies}") @@ -321,6 +326,7 @@ def train_dataloaders( cuts_train, max_duration=self.args.max_duration / num_copies, shuffle=self.args.shuffle, + seed=seed, ) logging.info("About to create train dataloader") @@ -328,14 +334,11 @@ def train_dataloaders( logging.info("Loading sampler state dict") train_sampler.load_state_dict(sampler_state_dict) - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - # need torch.distributed.barrier() before and after anything that might call lhotse.fix_random_seed() as it fixes random seeds of all GPUs, - # not just the GPU of this process. - dist_barrier() + # the data-loader workers do not have to be synchronized across the process-group, + # we can give them rank-dependent seeds. (There may not actually be any randomization + # at this level in this zapformer recipe though, we do SpecAug in the main process + # and I think the musan-related stuff happens in the sampler. + worker_init_fn = _SeedWorkers(seed + 4321 * rank) train_dl = DataLoader( train, sampler=train_sampler, @@ -344,7 +347,6 @@ def train_dataloaders( persistent_workers=False, worker_init_fn=worker_init_fn, ) - dist_barrier() return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index fc7c715056..8519f23ad8 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1080,15 +1080,6 @@ def compute_validation_loss( return tot_loss -def show_model_params(model: nn.Module): - with torch.no_grad(): - params = [ ] - for p in model.parameters(): - params.append(p.flatten()[-1:]) - all_last_elems = torch.cat(params) - logging.info(f"All last elems of parameters sum = {all_last_elems.sum()}") - - def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], @@ -1163,9 +1154,6 @@ def save_bad_model(suffix: str = ""): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx % 200 == 0: - show_model_params(model) - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -1455,6 +1443,7 @@ def run(rank, world_size, args): register_inf_check_hooks(model) asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(args.manifest_dir) gigaspeech = GigaSpeech(args.manifest_dir) # gigaspeech will only be used if the --use-giga=True option is set commonvoice = CommonVoice(args.manifest_dir) # commonvoice will only be used if the --use-cv=True option is set @@ -1558,6 +1547,8 @@ def remove_short_and_long_utt(c: Cut): train_dl = asr_datamodule.train_dataloaders( train_cuts, num_copies=1, + seed=params.seed, + rank=rank, ) scan_pessimistic_batches_for_oom( model=model, @@ -1585,17 +1576,13 @@ def remove_short_and_long_utt(c: Cut): train_cuts, sampler_state_dict=sampler_state_dict, num_copies=num_copies, + seed=params.seed + 500 * epoch, + rank=rank, ) sampler_state_dict=None - train_dl.sampler.set_epoch(epoch - 1) - # Re-do fixing the random seed because I believe in asr_datamodule.train_dataloaders(), fix_random_seed() - # may get called from an arbitrary worker with a worker-specific offset, and affect the seed of *all* the GPUs. - dist_barrier() - fix_random_seed(params.seed + epoch - 1) - # fix_random_seed may get called from an arbitrary worker with a - # worker-specific offset, and affect the seed of *all* the GPUs, with uncertain timing. so we - # call it again to make sure the seed is deterministic. - dist_barrier() + # we don't do : + # train_dl.sampler.set_epoch(epoch) + # because we just created the sampler and its seed already depends on the epoch. with torch.cuda.device(rank): # set CUDA seed for "my GPU" in a rank-dependent way. assume the only multi-node training we'll @@ -1603,7 +1590,7 @@ def remove_short_and_long_utt(c: Cut): # torch CPU random number generator for data augmentation- see time_warp()- # but this gets naturally desynchronized quite soon because it's called in a loop # that depends on the number of elements in a batch. - torch.cuda.manual_seed(params.seed + epoch - 1 + 1000 * rank) + torch.cuda.manual_seed(params.seed + 50 * epoch + 512 * rank) if tb_writer is not None: From de8b0afd4c8d6cae61fec6b00a26f63f2151f427 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 16:46:29 +0800 Subject: [PATCH 1050/1191] Bug fix importing numpy as np --- egs/librispeech/ASR/zapformer/asr_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index 3ee0110397..8a78c57a41 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import Any, Dict, Optional import random # to set its random seed -import numpy # to set its random seed +import numpy as np # to set its random seed import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy From 7746732e5db1cc5da27af4a752bf6b9901a00ecb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 17:36:56 +0800 Subject: [PATCH 1051/1191] Make limit_param_value non-randomized; also remove un-used python files. --- egs/librispeech/ASR/zapformer/scaling.py | 1295 ----------------- .../ASR/zapformer/scaling_converter.py | 99 -- .../ASR/zapformer/zapformer_utils.py | 6 +- 3 files changed, 2 insertions(+), 1398 deletions(-) delete mode 100644 egs/librispeech/ASR/zapformer/scaling.py delete mode 100644 egs/librispeech/ASR/zapformer/scaling_converter.py diff --git a/egs/librispeech/ASR/zapformer/scaling.py b/egs/librispeech/ASR/zapformer/scaling.py deleted file mode 100644 index 06cb538627..0000000000 --- a/egs/librispeech/ASR/zapformer/scaling.py +++ /dev/null @@ -1,1295 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import math -import copy -import random -from typing import Optional, Tuple, Union, Any - -import k2 -import torch -import torch.nn as nn -from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd - - - - -class FloatLike: # TODO: remove. this is to solve problems with multiple jobs running. - pass -class ScheduledFloat: # TODO: remove. this is to solve problems with multiple jobs running. - pass -class SimpleOrthogonalLinear: # TODO: remove. this is to solve problems with multiple jobs running. - pass -class PiecewiseLinear: # TODO: remove. this is to solve problems with multiple jobs running. - pass -class CosineSimilarityLoss: # TODO: remove. this is to solve problems with multiple jobs running. - pass -class PredictLoss: # TODO: remove. this is to solve problems with multiple jobs running. - pass -get_max_similarity = None # TODO: remove. this is to solve problems with multiple jobs running. - - - -def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: - max_value = torch.max(x, y) - diff = torch.abs(x - y) - return max_value + torch.log1p(torch.exp(-diff)) - - -# RuntimeError: Exporting the operator logaddexp to ONNX opset version -# 14 is not supported. Please feel free to request support or submit -# a pull request on PyTorch GitHub. -# -# The following function is to solve the above error when exporting -# models to ONNX via torch.jit.trace() -def logaddexp(x: Tensor, y: Tensor) -> Tensor: - # Caution(fangjun): Put torch.jit.is_scripting() before - # torch.onnx.is_in_onnx_export(); - # otherwise, it will cause errors for torch.jit.script(). - # - # torch.logaddexp() works for both torch.jit.script() and - # torch.jit.trace() but it causes errors for ONNX export. - # - if torch.jit.is_scripting(): - # Note: We cannot use torch.jit.is_tracing() here as it also - # matches torch.onnx.export(). - return torch.logaddexp(x, y) - elif torch.onnx.is_in_onnx_export(): - return logaddexp_onnx(x, y) - else: - # for torch.jit.trace() - return torch.logaddexp(x, y) - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.get_autocast_gpu_dtype()) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.amp.autocast('cuda', enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim=dim) - - return SoftmaxFunction.apply(x, dim) - - - -# all arg tensors are scalars. -def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): - stats = (x ** 2).mean(dim=2, keepdim=True) - T = x.shape[0] # time - if mask is None: - stats = stats.sum(dim=0) - lengths = T - else: - mask = (~mask).to(torch.float).t().unsqueeze(-1) - stats = stats * mask - stats = stats.sum(dim=0) - lengths = mask.sum(dim=0) - - scales = (lengths / stats).sqrt() - assert scales.shape == (x.shape[1], 1) - return x * ((scale * scales) + offset) - -# all arg tensors are scalars. -# mask only used in non-causal mode; ballast_rms and ballast_frames only used in causal mode. -def _causal_sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor): - stats = (x ** 2).mean(dim=2, keepdim=True) - - # no need for mask in causal mode. - # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so - # make absolutely sure using abs(). - ballast_frames = 100.0 * ballast_frames.abs() - ballast = ballast_frames * (ballast_rms ** 2) - T = x.shape[0] # time - - stats = stats.cumsum(dim=0) + ballast - lengths = ballast_frames + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] - - scales = (lengths / stats).sqrt() - assert scales.shape == (T, x.shape[1], 1) - return x * ((scale * scales) + offset) - - -# all arg tensors are scalars -def _causal_sequence_norm_streaming( - x: Tensor, - offset: Tensor, - scale: Tensor, - cached_stats_sum: Tensor, - cached_len: Tensor, -) -> Tuple[Tensor, Tensor, Tensor]: - """Streaming inference forward for _sequence_norm. We assume that ballast_frames and ballast_rms - are already included in cached_stats_sum and cached_len. - - Args: - x: (seq_len, batch_size, channels) - offset: scalar - scale: scalar - cached_stats_sum: (batch_size,) - cached_len: (batch_size,) - - Returns: - - normalized x, (seq_len, batch_size, channels) - - updated cached_stats_sum, (batch_size,) - - updated cached_len, (batch_size,) - """ - stats = (x ** 2).mean(dim=2, keepdim=True) # (seq_len, batch_size, 1) - - T = x.shape[0] # time - - stats = stats.cumsum(dim=0) + cached_stats_sum.unsqueeze(-1) - lengths = cached_len[:, None] + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] - - # update cached_stats_sum and cached_len for the next chunk - cached_stats_sum = stats[-1].squeeze(-1) # (batch_size,) - cached_len = cached_len + T - - scales = (lengths / stats).sqrt() # (T, batch_size, 1) - assert scales.shape == (T, x.shape[1], 1) - return x * ((scale * scales) + offset), cached_stats_sum, cached_len - - -class CausalSequenceNormFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - offset: Tensor, - scale: Tensor, - ballast_rms: Tensor, - ballast_frames: Tensor, - ) -> Tensor: - ctx.save_for_backward(x, offset, scale, ballast_rms, ballast_frames) - - return _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) - - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - x, offset, scale, ballast_rms, ballast_frames = ctx.saved_tensors - - - with torch.amp.autocast('cuda', enabled=False): - x = x.to(torch.float32).detach().requires_grad_() - offset = offset.to(torch.float32).detach().requires_grad_() - scale = scale.to(torch.float32).detach().requires_grad_() - ballast_rms = ballast_rms.to(torch.float32).detach().requires_grad_() - ballast_frames = ballast_frames.to(torch.float32).detach().requires_grad_() - - with torch.enable_grad(): - ans = _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) - ans.backward(gradient=ans_grad.to(torch.float32)) - - def c(x): - # this is to replace infinities that might be thrown up - # in autocast mode: scalars will tend to have larger grads than non-scalars, - # this code is to reduce the probabilities that any infinities could crash the - # training (it may still happen if the world-size is so large that these - # infinities get added together though). - return x.clamp_(min=-30000.0, max=30000.0) - - return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad) - -class SequenceNormFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - offset: Tensor, - scale: Tensor, - mask: Optional[Tensor], - ) -> Tensor: - ctx.save_for_backward(x, offset, scale) - ctx.mask = mask - - return _sequence_norm(x, offset, scale, mask) - - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - x, offset, scale = ctx.saved_tensors - - with torch.amp.autocast('cuda', enabled=False): - x = x.to(torch.float32).detach().requires_grad_() - offset = offset.to(torch.float32).detach().requires_grad_() - scale = scale.to(torch.float32).detach().requires_grad_() - - with torch.enable_grad(): - ans = _sequence_norm(x, offset, scale, ctx.mask) - ans.backward(gradient=ans_grad.to(torch.float32)) - - def c(x): - # this is to replace infinities that might be thrown up - # in autocast mode: scalars will tend to have larger grads than non-scalars, - # this code is to reduce the probabilities that any infinities could crash the - # training (it may still happen if the world-size is so large that these - # infinities get added together though). - return x if x is None else x.clamp_(min=-30000.0, max=30000.0) - - return x.grad, c(offset.grad), c(scale.grad), None - - -class CausalSequenceNorm(torch.nn.Module): - """ - This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence - up to the current point as well as the channels, with some padding of the stats with "default values" - determined by ballast_frames, ballast_rms for robustness near the beginning of the sequence. - - There is also a learnable scalar scale, multiplicatively applied to the output, and a learnable - "offset" value that acts multiplicatively on the input without taking into account the rms values. - """ - def __init__( - self, - ) -> None: - super().__init__() - self.scale = nn.Parameter(torch.tensor(0.5)) - self.offset = nn.Parameter(torch.tensor(0.0001)) - - # ballast_mean: assumed rms value of ballast frames used to pad stats - self.ballast_rms = nn.Parameter(torch.tensor(0.1)) - # ballast_frames: number of ballast frames, in hundreds (will be multiplied by 100) - self.ballast_frames = nn.Parameter(torch.tensor(0.05)) # number of ballast frames, will be multiplied by 100 - self.name = None - - def forward(self, x: Tensor, _mask: Optional[Tensor] = None) -> Tensor: - # x: (seq, batch, channel) - # The mask is ignored, it is allowed only for consistency of interface with SequenceNorm. - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _causal_sequence_norm(x, self.offset, self.scale, self.ballast_rms, self.ballast_frames) - - scale = limit_param_value( - self.scale, min=0.05, max=2.0, training=self.training) - - offset = limit_param_value( - self.offset, min=0.0, max=10.0, training=self.training) - - ballast_rms = limit_param_value( - self.ballast_rms, min=0.0, max=10.0, training=self.training) - - ballast_frames = limit_param_value( - self.ballast_frames, min=0.0, max=5.0, training=self.training) # max of 5.0 would be 500 frames - - ans = CausalSequenceNormFunction.apply( - x, offset, scale, ballast_rms, ballast_frames, - ) - - if random.random() < 0.002: - x_rms = (x ** 2).mean().sqrt() - ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}, ballast_rms={self.ballast_rms.item()}, ballast_frames*100={100*self.ballast_frames.item()}") - - return ans - - @torch.jit.export - def get_init_cache(self, batch_size: int): - """Get initial cache for streaming inference. We first include the ballast stats in the initial cache. - """ - # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so - # make absolutely sure using abs(). - ballast_frames = 100.0 * self.ballast_frames.abs() - ballast = ballast_frames * (self.ballast_rms ** 2) - - cached_stats_sum = ballast.unsqueeze(0).repeat(batch_size) # (batch_size,) - cached_len = ballast_frames.unsqueeze(0).repeat(batch_size) # (batch_size,) - - return cached_stats_sum, cached_len - - def streaming_forward( - self, - x: Tensor, - cached_stats_sum: Tensor, - cached_len: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor]: - - x, cached_stats_sum, cached_len = _causal_sequence_norm_streaming( - x, self.offset, self.scale, cached_stats_sum, cached_len) - return x, cached_stats_sum, cached_len - - -class SequenceNorm(torch.nn.Module): - """ - This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence - as well as the channels; and a padding mask is used for irregular length sequences (actually, - the mask is applied multiplicatively as well.) - - There is also a learnable scalar scale and a learnable "offset" value. - """ - def __init__( - self, - ) -> None: - super().__init__() - self.scale = nn.Parameter(torch.tensor(0.5)) - self.offset = nn.Parameter(torch.tensor(0.0001)) - self.name = None - - def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: - # x: (seq, batch, channel) - # mask: bool, (batch_size, seq_len) - # Note: mask is ignored in causal mode. - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _sequence_norm(x, self.offset, self.scale, mask) - - scale = limit_param_value( - self.scale, min=0.05, max=2.0, training=self.training) - - offset = limit_param_value( - self.offset, min=0.0, max=10.0, training=self.training) - - ans = SequenceNormFunction.apply( - x, offset, scale, mask, - ) - - if random.random() < 0.002: - x_rms = (x ** 2).mean().sqrt() - ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}") - - return ans - - - -# assume layout: (time, batch, channel) -def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): - x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) - scales = scale / x_sq.sqrt() - return x * scales - - - -class RmsNormFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - eps: Tensor, - scale: Tensor, - ) -> Tensor: - ctx.save_for_backward(x, eps, scale) - return _rms_norm(x, eps, scale) - - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - x, eps, scale = ctx.saved_tensors - - with torch.amp.autocast('cuda', enabled=False): - x, eps, scale = x.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) - x, eps, scale = x.detach(), eps.detach(), scale.detach() - - x.requires_grad = True - eps.requires_grad = True - scale.requires_grad = True - - with torch.enable_grad(): - ans = _rms_norm(x, eps, scale) - ans.backward(gradient=ans_grad.to(torch.float32)) - - def c(x): - # this is to replace infinities that might be thrown up - # in autocast mode. - return x.clamp_(min=-30000.0, max=30000.0) - - return x.grad, c(eps.grad), c(scale.grad) - - -class RmsNorm(torch.nn.Module): - """ - This is like RMSNorm with a trainable scale. - - """ - def __init__( - self, - ) -> None: - super(RmsNorm, self).__init__() - self.scale = nn.Parameter(torch.tensor(0.2)) # output scale - self.eps = nn.Parameter(torch.tensor(0.1)) - self.name = None - - - def forward(self, x: Tensor) -> Tensor: - # Assumes layout is (time, batch, channel) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return _rms_norm(x, self.eps, self.scale) - - scale = limit_param_value( - self.scale, min=0.05, max=1.0, training=self.training) - - eps = limit_param_value( - self.eps, min=0.0, max=10.0, training=self.training) - - ans = RmsNormFunction.apply( - x, eps, scale, - ) - - if random.random() < 0.002: - x_rms = (x ** 2).mean().sqrt() - ans_rms = (ans ** 2).mean().sqrt() - logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, eps={eps.item()}, scale={scale.item()}") - - return ans - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.01 * initial_scale, 0.01 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: - """ - Behaves like a constructor of a modified version of nn.Conv2d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False, but: - NO PADDING-RELATED ARGS. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv2d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -class OrthogonalPenaltyFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, weight: Tensor, penalty_scale: float, name: str): - ctx.save_for_backward(weight) - ctx.name = name - ctx.penalty_scale = penalty_scale - return weight - - @staticmethod - @custom_bwd - def backward(ctx, weight_grad): - weight, = ctx.saved_tensors - - if weight.requires_grad and ctx.penalty_scale != 0.0: - penalty_scale = ctx.penalty_scale * weight_grad.abs().mean() - - with torch.enable_grad(): - weight = weight.detach() - weight.requires_grad = True - - # Compute symmetric matrix-product prod with the smallest - # dimension possible given the shape of w. This is not just for - # efficiency; if we computed it the wrong way round, the product - # would have deficient rank and could never be the identity. - if (weight.shape[0] > weight.shape[1]): - prod = torch.matmul(weight.t(), weight) - else: - prod = torch.matmul(weight, weight.t()) - - # we'll try to enforce that for any i, prod[i] is any constant times the identity. - - # in the loss-function: - # orthogonality_loss = ((prod - I) ** 2).sum(), - - # note, prod_diag shares memory with prod, this will matter later on. - (r, c) = prod.shape - (r_stride, c_stride) = prod.stride() - - def diag_inplace(z): - return torch.as_strided(z, size=(r,), stride=(r_stride+c_stride,)) - - diag_inplace(prod)[:] -= 1. - - # that loss that we want to backprop would be 0.5 * (prod ** - # 2).sum() * penalty_scale. we can backprop this without doing - # any reductions as follows: - prod.backward(gradient=prod * penalty_scale) - - - do_print = random.random() < 0.002 - if do_print: - # we print a normalized version of the loss, by dividing by the - # number of rows. - loss = (prod ** 2).mean() - logging.info(f"OrthogonalLinear: name={ctx.name}, loss={loss.detach().cpu()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}") - - - # add the extra gradient term from the orthogonality loss. - weight_grad = weight_grad + weight.grad - return weight_grad, None, None - -class OrthogonalLinear(nn.Linear): - """ - Like nn.Linear but can enforce that the weight matrix is orthogonal; in the non-square - case this is interpreted as either M^T M == I or M M^T == I, whichever would give a smaller - dimension. - (If M is square, these definitions are equivalent and is equivalent to the normal - definition of orthogonal). - - Args: - in_channels: number of input channels - out_channels: number of output channels - lr_scale: we will scale the weight by this value before applying the orthogonal - constraint and using it; with most optimizers - this will have the effect of slowing down the learning by this factor because - the parameter value will be larger. - bias: if True, include a bias term. - penalty_scale: a scale on the penalty on non-orthogonality (this will - be multiplied by the average-absolute-value of the - backpropagated gradient). - """ - # if in_groups or out_groups are set to >1, the orthogonal constraint - # will be set per group. both of them cannot be >1. - def __init__(self, - in_channels: int, - out_channels: int, - lr_scale: float = 1.0, - bias: bool = True, - penalty_scale: float = 20.0, - ): - super().__init__(in_channels, out_channels, bias=bias) - self.name = None - self.penalty_scale = copy.deepcopy(penalty_scale) - self.lr_scale = lr_scale - - with torch.no_grad(): - self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * (1. / lr_scale) - if self.bias is not None: - torch.nn.init.uniform_(self.bias, -0.01, 0.01) - - - def forward(self, x: Tensor, transpose: bool = False): - # you can only use transpose=True if you used bias=False in initialization - weight = self.weight - lr_scale = self.lr_scale - if lr_scale != 1.0: - weight = weight * lr_scale - if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): - weight = OrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) - - if transpose: - weight = weight.t() - return torch.nn.functional.linear(x, weight, self.bias) - - -class ScaleLimiterFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, max_rms: float, aux_loss_scale: float, name: str): - ctx.save_for_backward(x) - ctx.max_rms = max_rms - ctx.aux_loss_scale = aux_loss_scale - ctx.name = name - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors - with torch.enable_grad(): - with torch.amp.autocast('cuda', enabled=False): - x = x.to(torch.float) - x = x.detach() - x.requires_grad = True - rms = (x ** 2).mean(dim=-1).sqrt() - numel = rms.numel() - - excess = (rms / ctx.max_rms - 1.).relu().mean() - - if random.random() < 0.002: - logging.info( - f"ScaleLimiter: name={ctx.name}, max_rms={ctx.max_rms}, " - f"rms={rms.mean().item()}, excess={excess.item()}, " - f"loss_scale={ctx.aux_loss_scale}" - ) - excess.backward(gradient=torch.full_like(excess, ctx.aux_loss_scale * numel)) - return x_grad + x.grad, None, None, None - - -class ScaleLimiter(torch.nn.Module): - """ - Adds a penalty in backprop if the norm of any activation vector is less than min_rms - or more than max_rms. - - Assumes channel dim is -1 and the input shape has >1 dimension. - """ - def __init__(self, max_rms: float): - super().__init__() - self.name = None - self.max_rms = max_rms - - - def forward(self, x: Tensor, aux_loss_scale: float) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return _no_op(x) - else: - return ScaleLimiterFunction.apply(x, float(self.max_rms), - aux_loss_scale, self.name) - - -class CorrelationLimiterFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, aux_loss_scale: float, limit: float, mask: Optional[Tensor], name: str): - ctx.save_for_backward(x) - ctx.mask = mask - ctx.limit = limit - ctx.aux_loss_scale = aux_loss_scale - ctx.name = name - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): # assume ans_grad is 1.0 - x, = ctx.saved_tensors - mask = ctx.mask - aux_loss_scale = ctx.aux_loss_scale - (batch_size, seq_len, num_channels) = x.shape - - with torch.enable_grad(): - with torch.amp.autocast('cuda', enabled=False): - x = x.to(torch.float) - x = x.detach() - x.requires_grad = True - x_orig = x - - def norm(x: Tensor): - eps = 1.0e-20 - return x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt() - x = norm(x) - - if mask is not None: - mask = (~mask).to(x.dtype).unsqueeze(-1) - x = x * mask - - half_batch = batch_size // 2 - if half_batch <= 1: - # the reason we also return None if half_batch==1 is because of CR-CTC - # where they may really be duplicates - return None, None, None, None, None - - - #x = torch.cat((x, y), dim=-1) - C = x.shape[-1] # num_channels - x1, x2 = x[0::2], x[1::2] - x1 = x1.reshape(-1, C) - x2 = x2.reshape(-1, C) - - if mask is not None: - numel1 = mask[0::2].sum() - numel2 = mask[1::2].sum() - else: - numel1 = x1.shape[0] - numel2 = x2.shape[0] - - S1 = torch.matmul(x1.t(), x1) * (1. / numel1) - S2 = torch.matmul(x2.t(), x2) * (1. / numel2) - - # S1, S2: (N, N) where N = min(num_channels, max_channels) - correlation = (S1 * S2).mean() - loss = (correlation - ctx.limit).relu() - - if random.random() < 0.0001: - logging.info( - f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, loss={loss}" - ) - - loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device)) - - - return x_orig.grad, None, None, None, None - - -class CorrelationLimiter(torch.nn.Module): - """ - Adds a penalty in backprop if the input feature has a covariance matrix that is - too different from the identity matrix. limit=1/num_channels is the - smallest limit you can provide but the limit should be much larger than - this, like 1/sqrt(num_channels). - - Assumes input is (batch, seq, channel) - """ - def __init__(self, limit: float = 0.03): - super().__init__() - self.name = None - self.limit = limit - - - def forward(self, x: Tensor, aux_loss_scale: float, mask: Optional[Tensor]) -> Tensor: - # x should be: (batch, seq, channel) - # returns a scalar tensor that should be included in the loss function with: - # z = with_loss(z, ret, None) - # where z is any quantity that will be used in calculating the main loss. - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return torch.tensor(0.0, device=x.device) - else: - return CorrelationLimiterFunction.apply(x, - aux_loss_scale, - float(self.limit), - mask, - self.name) - - - - -def penalize_abs_values_gt( - x: Tensor, limit: float, penalty: float, name: str = None -) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - - The name is for randomly printed debug info. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss, name) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor, name: str): - ctx.y_shape = y.shape - ctx.dtype = y.dtype - if random.random() < 0.002 and name is not None: - loss_sum = y.sum().item() - logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ( - ans_grad, - torch.ones(ctx.y_shape, dtype=ctx.dtype, device=ans_grad.device), - None, - ) - - -def with_loss(x, y, name=None): - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y, name) - - -class ScaleGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, alpha: float) -> Tensor: - ctx.alpha = alpha - return x - - @staticmethod - def backward(ctx, grad: Tensor): - return grad * ctx.alpha, None - - -def scale_grad(x: Tensor, alpha: float): - return ScaleGradFunction.apply(x, alpha) - - -class ScaleGrad(nn.Module): - def __init__(self, alpha: float): - super().__init__() - self.alpha = alpha - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return x - return scale_grad(x, self.alpha) - - -class LimitParamValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, min: float, max: float): - ctx.save_for_backward(x) - assert max >= min - ctx.min = min - ctx.max = max - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x,) = ctx.saved_tensors - # where x < ctx.min, ensure all grads are negative (this will tend to make - # x more positive). - x_grad = x_grad * torch.where( - torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 - ) - # where x > ctx.max, ensure all grads are positive (this will tend to make - # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) - return x_grad, None, None - - -def limit_param_value( - x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True -): - # You apply this to (typically) an nn.Parameter during training to ensure that its - # (elements mostly) stays within a supplied range. This is done by modifying the - # gradients in backprop. - # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: - return LimitParamValue.apply(x, min, max) - else: - return x - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - - - - - -def torch_compile(fn, *args, **kwargs): - if hasattr(torch, 'compile'): - fn = torch.compile(fn, *args, **kwargs, dynamic=True, options={"shape_padding": True, "force_shape_pad": True}) - return fn - -def swashl(x: Tensor) -> Tensor: - zero = torch.zeros_like(x) - return 0.25 * logaddexp(zero, 4 * x - 4.0) - 0.08 * x - 0.00875 - -def swashr(x: Tensor) -> Tensor: - zero = torch.zeros_like(x) - return 0.25 * logaddexp(zero, 4 * x - 1.0) - 0.08 * x - 0.07831542175 - - -def swashl_and_deriv(x: Tensor): - x_offset = 4. * x - 4. - denom = 1. + x_offset.exp() - inv_denom = 1. / denom # note: 1 / infinity = 0. - deriv = 0.92 - inv_denom; - log_denom = denom.log() - log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) - y = 0.25 * log_denom - 0.08 * x - 0.00875 - return y, deriv - -def swashr_and_deriv(x: Tensor): - x_offset = 4. * x - 1. - denom = 1. + x_offset.exp() - inv_denom = 1. / denom # note: 1 / infinity = 0. - deriv = 0.92 - inv_denom; - log_denom = denom.log() - log_denom = torch.where(torch.isinf(log_denom), x_offset, log_denom) - y = 0.25 * log_denom - 0.08 * x - 0.07831542175 - return y, deriv - - -class SwashL(torch.nn.Module): - def __init__(self): - super().__init__() - self.func = torch_compile(swashl) - def forward(self, x: Tensor) -> Tensor: - """Return Swash-L activation, which is the same as SwooshL but with a factor of 4 - on the input and 0.25 on the output..""" - return self.func(x) - -class SwashR(torch.nn.Module): - def __init__(self): - super().__init__() - self.func = torch_compile(swashr) - def forward(self, x: Tensor) -> Tensor: - """Return Swash-R activation, which is the same as SwooshL but with a factor of 4 - on the input and 0.25 on the output..""" - return self.func(x) - - - -class ActivationAndLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - forward_func: Any, - backward_func: Any, - ): - ctx.save_for_backward(x, weight, bias) - - ctx.backward_func = backward_func - - x = forward_func(x) - x = torch.nn.functional.linear(x, weight, bias) - return x - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad: Tensor): - saved = ctx.saved_tensors - (x, weight, bias) = saved - - y, func_deriv = ctx.backward_func(x) - # now compute derivative of y w.r.t. weight and bias.. - # y: (..., in_channels), ans_grad: (..., out_channels), - (out_channels, in_channels) = weight.shape - - in_channels = y.shape[-1] - g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) - y_deriv = torch.matmul(ans_grad, weight) - bias_deriv = None if bias is None else g.sum(dim=0) - x_deriv = y_deriv * func_deriv - return x_deriv, weight_deriv, bias_deriv, None, None - - - -class ActivationAndLinear(torch.nn.Module): - """ - This merges an activation function followed by a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwashL, this will be - equivalent to: - nn.Sequential(SwashL(), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwashL, SwashR. - """ - def __init__( - self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = "SwashL", - initial_scale: float = 1.0, - ): - super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from - l = ScaledLinear( - in_channels, out_channels, bias=bias, initial_scale=initial_scale - ) - - self.weight = l.weight - # register_parameter properly handles making it a parameter when l.bias - # is None. I think there is some reason for doing it this way rather - # than just setting it to None but I don't know what it is, maybe - # something to do with exporting the module.. - self.register_parameter("bias", l.bias) - - self.activation = activation - - assert activation in ["SwashL", "SwashR"] - if activation == "SwashL": - self.forward_func = torch_compile(swashl) - self.backward_func = torch_compile(swashl_and_deriv) - else: - self.forward_func = torch_compile(swashr) - self.backward_func = torch_compile(swashr_and_deriv) - - - def forward(self, x: Tensor): - if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): - x = self.forward_func(x) - return torch.nn.functional.linear(x, self.weight, self.bias) - - return ActivationAndLinearFunction.apply( - x, - self.weight, - self.bias, - self.forward_func, - self.backward_func, - ) - - -def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: - if num_channels <= x.shape[-1]: - return x[..., :num_channels] - else: - shape = list(x.shape) - shape[-1] = num_channels - shape[-1] - zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat((x, zeros), dim=-1) - - - -def _test_swashl_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwashL() - - tol = 1.0 / 255.0 - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swashr_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwashR() - - tol = 1.0 / 255.0 - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_activation_and_linear(): - in_channels = 20 - out_channels = 30 - - for bias in [True, False]: - if True: - for activation in ["SwashL", "SwashR"]: - m1 = nn.Sequential( - SwashL() if activation == "SwashL" else SwashR(), - ScaledLinear( - in_channels, out_channels, bias=bias, initial_scale=0.5 - ), - ) - m2 = ActivationAndLinear( - in_channels, - out_channels, - bias=bias, - initial_scale=0.5, - activation=activation, - ) - with torch.no_grad(): - m2.weight[:] = m1[1].weight - if bias: - m2.bias[:] = m1[1].bias - # make sure forward gives same result. - x1 = torch.randn(10, in_channels) - x1.requires_grad = True - - - x2 = x1.clone().detach() - x2.requires_grad = True - seed = 10 - torch.manual_seed(seed) - y1 = m1(x1) - y_grad = torch.randn_like(y1) - y1.backward(gradient=y_grad) - torch.manual_seed(seed) - y2 = m2(x2) - y2.backward(gradient=y_grad) - - print( - f"bias = {bias}, activation = {activation}" - ) - print("y1 = ", y1) - print("y2 = ", y2) - assert torch.allclose(y1, y2, atol=0.02) - print("grad1 = ", m1[1].weight.grad) - print("grad2 = ", m2.weight.grad) - - assert torch.allclose(m1[1].weight.grad, m2.weight.grad, atol=1.0e-05) - if bias: - assert torch.allclose(m1[1].bias.grad, m2.bias.grad, atol=1.0e-05) - print("x1.grad = ", x1.grad) - print("x2.grad = ", x2.grad) - - def isclose(a, b): - # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ( - (a**2).sum() * (b**2).sum() - ).sqrt() - - # the SwashL() implementation has a noisy gradient due to 1-byte - # storage of it. - assert isclose(x1.grad, x2.grad) - - -def _test_orthogonal_linear(): - m = OrthogonalLinear(128, 128) - m(torch.randn(30, 2, 128)) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_softmax() - _test_swashr_deriv() - _test_swashl_deriv() - _test_activation_and_linear() - _test_orthogonal_linear() diff --git a/egs/librispeech/ASR/zapformer/scaling_converter.py b/egs/librispeech/ASR/zapformer/scaling_converter.py deleted file mode 100644 index e4ee960838..0000000000 --- a/egs/librispeech/ASR/zapformer/scaling_converter.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, Whiten is replaced with an identity operator. -""" - -import copy -from typing import List - -import torch -import torch.nn as nn -from scaling import ( - SwashL, - SwashLOnnx, - SwashR, - SwashROnnx, -) -from zapformer import RelPosScores - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_pnnx: bool = False, - is_onnx: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_pnnx: - True if we are going to export the model for PNNX. - is_onnx: - True if we are going to export the model for ONNX. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, (Dropout3, ScaleGrad, Whiten)): - d[name] = nn.Identity() - elif is_onnx and isinstance(m, SwashR): - d[name] = SwashROnnx() - elif is_onnx and isinstance(m, SwashL): - d[name] = SwashLOnnx() - elif is_onnx and isinstance(m, RelPosScores): - # We want to recreate the positional encoding vector when - # the input changes, so we have to use torch.jit.script() - # to replace torch.jit.trace() - d[name] = torch.jit.script(m) - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/librispeech/ASR/zapformer/zapformer_utils.py b/egs/librispeech/ASR/zapformer/zapformer_utils.py index 4b8b1dc8cd..6d04f95c80 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_utils.py +++ b/egs/librispeech/ASR/zapformer/zapformer_utils.py @@ -148,14 +148,12 @@ def backward(ctx, x_grad: Tensor): def limit_param_value( - x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True + x: Tensor, min: float, max: float, training: bool = True ): # You apply this to (typically) an nn.Parameter during training to ensure that its # (elements mostly) stays within a supplied range. This is done by modifying the # gradients in backprop. - # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: + if training: return LimitParamValue.apply(x, min, max) else: return x From 0871a46747880240feffe9d7913a0973c914c349 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 17:48:05 +0800 Subject: [PATCH 1052/1191] Move time_warp to alternating_spec_augment.py --- .../ASR/zapformer/alternating_spec_augment.py | 79 +++++++++++++++++++ egs/librispeech/ASR/zapformer/train.py | 3 +- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 927e780261..a815ca8bf0 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -286,6 +286,85 @@ def _test_alternating_spec_augment(): + +def time_warp_impl(features: torch.Tensor, factor: int) -> torch.Tensor: + """ + # modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py#L338C1-L369C1 + # to use torch rng rather than the numpy one, this has to do with which rngs + # are synchronized and which are not. (we keep the numpy and python rng's synchronized + # for the sake of lhotse's sampler code, where they need to be synchronized to avoid data + # overlap). + + Time warping as described in the SpecAugment paper. + Implementation based on Espresso: + https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51 + + :param features: input tensor of shape ``(T, F)`` + :param factor: time warping parameter. + :return: a warped tensor of shape ``(T, F)`` + """ + t = features.size(0) + if t - factor <= factor + 1: + return features + center = torch.randint(factor + 1, t - factor, ()).item() + warped = torch.randint(center - factor, center + factor + 1, ()).item() + if warped == center: + return features + features = features.unsqueeze(0).unsqueeze(0) + left = torch.nn.functional.interpolate( + features[:, :, :center, :], + size=(warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + right = torch.nn.functional.interpolate( + features[:, :, center:, :], + size=(t - warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) + + +# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +# it does not differ substantively from that; only, it accepts feature_lens rather than supervision +# segments, and uses torch as the random number generator. +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + feature_lens: Optional[torch.Tensor] = None, +): + if time_warp_factor is None or time_warp_factor < 1: + return features + assert ( + len(features.shape) == 3 + ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}" + features = features.clone() + + # we use torch.rand(1).item() instead of random.random() because for lhotse reasons we keep the + # python RNG synchronized across ranks, but we keep the torch RNG desynchronized. + if feature_lens is None: + # No feature_lens - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + if torch.rand(1).item() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx] = time_warp_impl( + features[sequence_idx], factor=time_warp_factor + ) + else: + for sequence_idx, num_frames in enumerate(feature_lens): + if torch.rand(1).item() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx, :num_frames] = time_warp_impl( + features[sequence_idx, :num_frames], factor=time_warp_factor + ) + + return features + + # from lhotse.dataset import SpecAugment if __name__ == '__main__': diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 8519f23ad8..42a6860e56 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -110,7 +110,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from alternating_spec_augment import AlternatingSpecAugment # using this, not lhotse's version of nn.Module +from alternating_spec_augment import AlternatingSpecAugment, time_warp from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -118,7 +118,6 @@ get_parameter_groups_with_lrs, setup_logger, str2bool, - time_warp, ) try: from icefall.utils import dist_barrier From e0fd17fdbcc372b885b0d41b327b60478cba56c2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 17:49:36 +0800 Subject: [PATCH 1053/1191] Fix import --- egs/librispeech/ASR/zipformer/attention_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py index bff536f90b..648be4b1e0 100644 --- a/egs/librispeech/ASR/zipformer/attention_decoder.py +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -23,7 +23,7 @@ import torch import torch.nn as nn from label_smoothing import LabelSmoothingLoss -from scaling import penalize_abs_values_gt +from zapformer_utils import penalize_abs_values_gt from icefall.utils import add_eos, add_sos, make_pad_mask From 0a94ca33fcf3ce3caa863019e8cb9446a895d8f0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 21:24:41 +0800 Subject: [PATCH 1054/1191] Do not require CTC --- egs/librispeech/ASR/zapformer/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 42a6860e56..dfd5105920 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1372,8 +1372,6 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - assert params.use_ctc # for now, require CTC, we may remove this requirement later. - assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: From bc202c235614b6f54c2e0fff0be5c29097c5da06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 23:12:13 +0800 Subject: [PATCH 1055/1191] Make depthwise_conv non-central weights 10 times smaller --- egs/librispeech/ASR/zapformer/zapformer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 1dc47ca549..bf12bbf3d1 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -1765,7 +1765,7 @@ def forward(self, return x.mean(dim=0) * (T / num_frames) * self.weights else: return x.mean(dim=0) * self.weights - + def streaming_forward( self, x: Tensor, @@ -1800,7 +1800,7 @@ def streaming_forward( new_cached_num_frames = cached_num_frames + T # (batch,) return output, new_cached_sum, new_cached_num_frames - + class BasisConv(nn.Module): def __init__(self, @@ -1910,7 +1910,13 @@ def __init__( bias=False, ) self.left_pad = kernel_size - 1 - self.depthwise_conv.lr_scale = 0.66 + + self.depthwise_conv.lr_scale = 0.66 # not sure whether to keep this, it wasn't very conclusive. + with torch.no_grad(): + # make the non-central convolution weights much smaller. + k2 = kernel_size // 2 + self.depthwise_conv.weight[..., :k2] *= 0.1 + self.depthwise_conv.weight[..., -k2:] *= 0.1 # add average-of-all-frames to the "convolution."; it has extra power vs the convolution # because the num frames differs between utterances. From 531917362736ccd079b1ee9a104c31fcf1a4cd3d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Apr 2026 23:13:25 +0800 Subject: [PATCH 1056/1191] use try-except when importing time_warp to avoid multi-job problem --- egs/librispeech/ASR/zapformer/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index dfd5105920..93675b07a4 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -110,7 +110,12 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from alternating_spec_augment import AlternatingSpecAugment, time_warp +from alternating_spec_augment import AlternatingSpecAugment +try: + from alternating_spec_augment import time_warp +except: + pass + from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, From a1b6764977d468d03f54dacf6f8f67cc33f2afc8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 11 Apr 2026 22:14:06 +0800 Subject: [PATCH 1057/1191] Improve code regarding random number generators, for greater clarity and locality. Should remove synchronization of time_warp() across ranks. --- .../ASR/zapformer/alternating_spec_augment.py | 178 +++++++++++------- egs/librispeech/ASR/zapformer/train.py | 82 +++----- 2 files changed, 141 insertions(+), 119 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index a815ca8bf0..b27c0336eb 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -23,7 +23,10 @@ def __init__( num_feature_masks: int = 2, max_frame_mask_fraction: float = 0.725, # the expected temporal masked-fraction is half of this. max_frame_mask_size: float = 70, # max size in frames of temporal masks. - p=0.9, # probability of doing augmentation + p=0.9, # probability of doing core SpecAug augmentation + time_warp_p=0.9, # probability of doing time warping. + time_warp_factor=80, # as in original SpecAug paper. + seed=None, # if you leave this as none it will use random.random() ): super().__init__() assert 0 <= p <= 1 @@ -38,12 +41,50 @@ def __init__( self.max_frame_mask_size = max_frame_mask_size self.p = p + self.time_warp_p = time_warp_p + self.time_warp_factor = time_warp_factor + + self.seed = seed + self.device_to_generator = dict() + + def get_generator(self, device): + try: + return self.device_to_generator[str(device)] + except KeyError: + gen = torch.Generator(device) + gen.manual_seed(self.seed if self.seed is not None else torch.randint(0, 100000, ()).item()) + self.device_to_generator[str(device)] = gen + return gen + + def forward( + self, + features: torch.Tensor, + feature_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Do augmentaiton and return modified features. + features: (batch_size, seq_len, num_channels) + feature_lens: (batch_size,), contains sequence lengths 0 < feature_lens <= seq_len + """ + if self.time_warp_p > 0: + features = time_warp(features, + p=self.time_warp_p, + time_warp_factor=self.time_warp_factor, + feature_lens=feature_lens, + generator=self.get_generator(torch.device('cpu'))) + if self.p > 0: + features = self.forward_masking(features) + return features + + def forward_masking( self, features: torch.Tensor, ) -> torch.Tensor: """ - Computes ExpAugment for a batch of feature matrices. + Computes ExpAugment for a batch of feature matrices. Caution: for time warping + the user should call self.time_warp() separately. It's a class member for purposes + of keeping track of generators. Since the batch will usually already be padded, the user can optionally provide a ``supervision_segments`` tensor that will be used to apply SpecAugment @@ -59,10 +100,6 @@ def forward( B, T, F = features.shape features = features.clone() - - # get feature means. - kwargs = {'device': features.device} - mean = features.mean() features_unmasked = features @@ -80,7 +117,8 @@ def forward( max_mask_fraction=self.max_frame_mask_fraction, num_masks=num_masks) - features = torch.where(torch.rand(B, 1, 1, **kwargs).expand_as(features) < self.p, + generator = self.get_generator(features.device) + features = torch.where(torch.rand(B, 1, 1, device=features.device, generator=generator).expand_as(features) < self.p, features, features_unmasked) return features @@ -159,6 +197,7 @@ def _mask_on_axis(self, def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_fraction, device) -> Tuple[Tuple,Tuple]: + generator = self.get_generator(device) # we imagine there are "pairs of sequences" for historical reasons but one of each pair is not # a real sequence. B = batch_size @@ -168,7 +207,7 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ # "rlength" means relative length of each mask, i.e. relative to seq_len. the # lengths in mask_lengths are normalized lengths. - mask_rlengths = torch.rand(B, M, device=device) * (max_mask_fraction / num_masks) + mask_rlengths = torch.rand(B, M, device=device, generator=generator) * (max_mask_fraction / num_masks) #if (seq_len + batch_size) % 10 == 0: # pseudo-randomly print the random numbers. i want to test repeatability. # logging.info(f"mask_rlengths: {mask_rlengths.flatten()[:10]}") mask_tot_rlen = mask_rlengths.sum(dim=1, keepdim=True) # (batch_size, 1) @@ -188,7 +227,7 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ P = M + 1 # rpositions means positions expressed in relative length, i.e. normalized so that # seq_len is 1. - padding_rpositions = torch.rand(B, P - 1, device=device) * padding_tot_rlen + padding_rpositions = torch.rand(B, P - 1, device=device, generator=generator) * padding_tot_rlen padding_rpositions = padding_rpositions.sort(dim=1)[0] zero = torch.zeros(B, 1, device=device) padding_rpositions = torch.cat((zero, padding_rpositions, padding_tot_rlen), dim=1) @@ -217,7 +256,7 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ # letting the start-position when we take alternating positions be # randomly 0 or 1 avoids any overall bias towards the start or end of # the sequence. - index = torch.randint(0, 2, (B,), device=device).unsqueeze(-1) + torch.arange(0, M, step=2, device=device) + index = torch.randint(0, 2, (B,), device=device, generator=generator).unsqueeze(-1) + torch.arange(0, M, step=2, device=device) mask_starts = torch.gather(mask_starts, dim=1, index=index) mask_ends = torch.gather(mask_ends, dim=1, index=index) @@ -238,56 +277,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]): setattr(self, name, state_dict["name"]) -def _test_alternating_spec_augment(): - for n in [ 0, 1 ]: - #device = 'cuda' - B, T, F = 301, 600, 80 - device = 'cpu' - - if n == 0: - aspec_augment = AlternatingSpecAugment() - else: - from lhotse.dataset import SpecAugment - time_mask_ratio = 3.5 - num_frame_masks = int(10 * time_mask_ratio) - max_frames_mask_fraction = 0.15 * time_mask_ratio - print( - f"num_frame_masks: {num_frame_masks}, " - f"max_frames_mask_fraction: {max_frames_mask_fraction}" - ) - spec_augment = SpecAugment( - time_warp_factor=0, # Do time warping in model.py - num_frame_masks=num_frame_masks, # default: 10 - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 - p=0.9, - ) - supervision_segments = torch.stack(( - torch.arange(B, device=device), # sequence_idx - torch.zeros(B, device=device, dtype=torch.long), # start_frame - T * torch.ones(B, device=device, dtype=torch.long) # num_frames - ), dim=-1) - aspec_augment = lambda x: spec_augment(x, supervision_segments) - - features = torch.randn(B, T, F, device=device) - lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) - #print("features=", features) - features = aspec_augment(features) - - frame_is_masked = features[:, :, 0] == features[:, :, -1] - print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) - - print("mean frame_is_masked[per-frame][::10] = ", frame_is_masked.to(torch.float).mean(dim=0)[::10]) - feature_is_masked = features[:, 0] == features[:, -1] - print("mean feature_is_masked = ", feature_is_masked.to(torch.float).mean()) - print("mean feature_is_masked[per-freq] = ", feature_is_masked.to(torch.float).mean(dim=0)) - - - - -def time_warp_impl(features: torch.Tensor, factor: int) -> torch.Tensor: +def time_warp_impl(features: torch.Tensor, factor: int, + generator: Optional[torch.Generator]) -> torch.Tensor: """ # modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py#L338C1-L369C1 # to use torch rng rather than the numpy one, this has to do with which rngs @@ -306,8 +297,8 @@ def time_warp_impl(features: torch.Tensor, factor: int) -> torch.Tensor: t = features.size(0) if t - factor <= factor + 1: return features - center = torch.randint(factor + 1, t - factor, ()).item() - warped = torch.randint(center - factor, center + factor + 1, ()).item() + center = torch.randint(factor + 1, t - factor, (), generator=generator).item() + warped = torch.randint(center - factor, center + factor + 1, (), generator=generator).item() if warped == center: return features features = features.unsqueeze(0).unsqueeze(0) @@ -334,6 +325,7 @@ def time_warp( p: float = 0.9, time_warp_factor: Optional[int] = 80, feature_lens: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, # generator for CPU only ): if time_warp_factor is None or time_warp_factor < 1: return features @@ -342,12 +334,12 @@ def time_warp( ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}" features = features.clone() - # we use torch.rand(1).item() instead of random.random() because for lhotse reasons we keep the - # python RNG synchronized across ranks, but we keep the torch RNG desynchronized. + # we use torch.rand(1).item() instead of random.random() for easier control of generators + # that is more consistent with GPU generators. if feature_lens is None: # No feature_lens - apply spec augment to full feature matrices. for sequence_idx in range(features.size(0)): - if torch.rand(1).item() > p: + if torch.rand(1, generator=generator).item() > p: # Randomly choose whether this transform is applied continue features[sequence_idx] = time_warp_impl( @@ -355,16 +347,68 @@ def time_warp( ) else: for sequence_idx, num_frames in enumerate(feature_lens): - if torch.rand(1).item() > p: + if torch.rand(1, generator=generator).item() > p: # Randomly choose whether this transform is applied continue features[sequence_idx, :num_frames] = time_warp_impl( - features[sequence_idx, :num_frames], factor=time_warp_factor + features[sequence_idx, :num_frames], factor=time_warp_factor, + generator=generator, ) return features + + +def _test_alternating_spec_augment(): + for n in [ 0, 1 ]: + #device = 'cuda' + B, T, F = 301, 600, 80 + device = 'cpu' + + if n == 0: + aspec_augment = AlternatingSpecAugment(time_warp_p=0.0) + else: + from lhotse.dataset import SpecAugment + time_mask_ratio = 3.5 + num_frame_masks = int(10 * time_mask_ratio) + max_frames_mask_fraction = 0.15 * time_mask_ratio + print( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + p=0.9, + ) + supervision_segments = torch.stack(( + torch.arange(B, device=device), # sequence_idx + torch.zeros(B, device=device, dtype=torch.long), # start_frame + T * torch.ones(B, device=device, dtype=torch.long) # num_frames + ), dim=-1) + aspec_augment = lambda x: spec_augment(x, supervision_segments) + + features = torch.randn(B, T, F, device=device) + lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) + #print("features=", features) + features = aspec_augment(features) + + frame_is_masked = features[:, :, 0] == features[:, :, -1] + print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean()) + + print("mean frame_is_masked[per-frame][::10] = ", frame_is_masked.to(torch.float).mean(dim=0)[::10]) + feature_is_masked = features[:, 0] == features[:, -1] + print("mean feature_is_masked = ", feature_is_masked.to(torch.float).mean()) + print("mean feature_is_masked[per-freq] = ", feature_is_masked.to(torch.float).mean(dim=0)) + + + + # from lhotse.dataset import SpecAugment if __name__ == '__main__': diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index cd4b89ebc9..ab4c9dcd75 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -111,10 +111,6 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from alternating_spec_augment import AlternatingSpecAugment -try: - from alternating_spec_augment import time_warp -except: - pass from icefall.hooks import register_inf_check_hooks from icefall.utils import ( @@ -401,7 +397,7 @@ def get_parser(): parser.add_argument( "--max-copies", type=int, - default=16, + default=1, help="The num_copies to use in the dataloader on the last epoch (it rises linearly with step count from --min-copies)" ) @@ -919,40 +915,6 @@ def save_checkpoint( -def augmentation( - features: Tensor, - feature_lens: Tensor) -> Tensor: - """ - - Args: - features: a Tensor of shape (batch_size, seq_len, num_channels) - - Returns: - augmented_features - """ - (batch_size, seq_len, num_channels) = features.shape - - do_time_warp = True - - if do_time_warp: - with torch.amp.autocast('cuda', enabled=False): - features = time_warp( - features.to(torch.float), - time_warp_factor=80, - feature_lens=feature_lens, - ) - - # note: AlternatingSpecAugment() does *somewhat* assume that x consists of two copies of - # the same data, but practically speaking the only important use this is put - # to is that it chooses non-overlapping frequency regions to mask. it also - # chooses non-overlapping time regions to mask, but this is not so important - # since the time warping (if used) was done independently on the two copies. - spec_augment = AlternatingSpecAugment() - features = spec_augment(features) - - return features - - def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], @@ -960,6 +922,7 @@ def compute_loss( batch: dict, is_training: bool, aux_loss_scale: float = 0.0, + specaug: Optional[nn.Module] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -992,11 +955,9 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) - - if is_training: - batch_size = features.shape[0] - features = augmentation(features, feature_lens) - + if specaug is not None: + with torch.amp.autocast('cuda', enabled=False): + features = specaug(features.to(torch.float), feature_lens) with torch.set_grad_enabled(is_training): simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( @@ -1095,6 +1056,7 @@ def train_one_epoch( scaler: GradScaler, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, + specaug: Optional[nn.Module] = None, world_size: int = 1, rank: int = 0, ) -> None: @@ -1171,6 +1133,7 @@ def save_bad_model(suffix: str = ""): sp=sp, batch=batch, is_training=True, + specaug=specaug, aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), ) # summary stats @@ -1316,7 +1279,9 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) + # synchronize seeds. important for parameter initialization to be consistent. fix_random_seed(params.seed) + if world_size > 1: setup_dist(rank, world_size, params.master_port) # need torch.distributed.barrier() after fix_random_seed() as it fixes @@ -1571,9 +1536,13 @@ def remove_short_and_long_utt(c: Cut): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) + for epoch in range(params.start_epoch, num_epochs + 1): # fix all random seeds before starting the dataloaders, as they require - # all seeds to be synchronized. + # all seeds to be synchronized, in particular for the sampler, which + # uns in the main process and relies on the currently-set random seed + # (in practice it's just the random module's + # seed and possibly the numpy seed that really matter here. dist_barrier() fix_random_seed(params.seed + epoch - 1) dist_barrier() @@ -1587,19 +1556,27 @@ def remove_short_and_long_utt(c: Cut): seed=params.seed + 500 * epoch, rank=rank, ) + sampler_state_dict=None # we don't do : # train_dl.sampler.set_epoch(epoch) # because we just created the sampler and its seed already depends on the epoch. - with torch.cuda.device(rank): - # set CUDA seed for "my GPU" in a rank-dependent way. assume the only multi-node training we'll - # do is with cuda so do not worry about CPU seed. in fact, we do also rely on the - # torch CPU random number generator for data augmentation- see time_warp()- - # but this gets naturally desynchronized quite soon because it's called in a loop - # that depends on the number of elements in a batch. - torch.cuda.manual_seed(params.seed + 50 * epoch + 512 * rank) + seed = params.seed + 50 * epoch + 512 * rank + + specaug = AlternatingSpecAugment( + seed=seed, + ) # otherwise use all default settings. + + if torch.cuda.is_available(): + with torch.cuda.device(rank): + # set CUDA seed for "my GPU" in a rank-and-epoch-dependent way. + # This is not not very important, it should just affect the + # AddNoise() module in subsampling.py + torch.cuda.manual_seed(seed) + else: + torch.manual_seed(seed) if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) @@ -1619,6 +1596,7 @@ def remove_short_and_long_utt(c: Cut): valid_dl=valid_dl, scaler=scaler, tb_writer=tb_writer, + specaug=specaug, world_size=world_size, rank=rank, ) From c64b00ecbbc0c8873c33bd527535a3958dcb0997 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 12 Apr 2026 13:20:46 +0800 Subject: [PATCH 1058/1191] Remove correlation limiter. --- egs/librispeech/ASR/zapformer/zapformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index bf12bbf3d1..4566a95fd6 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -546,8 +546,8 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) - power = 0.4 # power should be between 0 and 1. 1 would mean cov == I (unattainable) - self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) + #power = 0.4 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + #self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn = MultiheadRelPosGatedSelfAttention( embed_dim, @@ -595,9 +595,9 @@ def forward( """ src_orig = src - src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), - 2. * aux_loss_scale, mask=src_key_padding_mask), - None) + #src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), + # 2. * aux_loss_scale, mask=src_key_padding_mask), + # None) src_pre_ff1 = src From 1d06aa86405efdcd34a7ff532cc9759d7e1e44f6 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 13 Apr 2026 11:21:33 +0800 Subject: [PATCH 1059/1191] Refactor AngularFreqBasis for caching and reuse --- egs/librispeech/ASR/zapformer/zapformer.py | 92 +++++++++++++++++----- 1 file changed, 72 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index bf12bbf3d1..e305d91372 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -101,6 +101,7 @@ def __init__( num_heads: Union[int, Tuple[int]] = 8, feedforward_multiple: Union[int, Tuple[int]] = 4, conv_params: Union[int, Tuple[int]] = 31, + num_freqs: int = 64, causal: bool = False, chunk_size: Tuple[int] = [-1], left_context_frames: Tuple[int] = [-1], @@ -130,6 +131,7 @@ def _to_tuple(x): self.num_heads = num_heads = _to_tuple(num_heads) feedforward_multiple = _to_tuple(feedforward_multiple) self.conv_params = conv_params = _to_tuple(conv_params) + self.num_freqs = num_freqs self.causal = causal self.chunk_size = (chunk_size,) if isinstance(chunk_size, int) else chunk_size @@ -156,6 +158,7 @@ def _to_tuple(x): pos_head_dim=pos_head_dim[i], feedforward_multiple=feedforward_multiple[i], conv_params=conv_params[i], + num_freqs=num_freqs, causal=causal, ) @@ -171,6 +174,12 @@ def _to_tuple(x): self.encoders = nn.ModuleList(encoders) + # Share a single AngularFreqBasis instance across all layers within each encoder stack + for encoder in self.encoders: + shared_basis = AngularFreqBasis(num_freqs=num_freqs) + for layer in encoder.layers: + layer.self_attn.rel_pos.angular_freq_basis = shared_basis + self.out_norm = RmsNorm() @@ -538,6 +547,7 @@ def __init__( pos_head_dim: int, feedforward_multiple: int, conv_params: int, + num_freqs: int = 64, causal: bool = False, ) -> None: super(ZapformerEncoderLayer, self).__init__() @@ -555,6 +565,7 @@ def __init__( query_head_dim=query_head_dim, value_head_dim=value_head_dim, pos_head_dim=pos_head_dim, + num_freqs=num_freqs, causal=causal, ) @@ -949,7 +960,8 @@ def __init__( query_head_dim: int, pos_head_dim: int , value_head_dim: int, - causal: bool, + num_freqs: int = 64, + causal: bool = False, ) -> None: super().__init__() self.embed_dim = embed_dim @@ -972,7 +984,7 @@ def __init__( bias=True, initial_scale=0.125 * query_head_dim**-0.25 ) - self.rel_pos = RelPosScores(num_heads, pos_head_dim, num_freqs=64) + self.rel_pos = RelPosScores(num_heads, pos_head_dim, num_freqs=num_freqs) self.copy_query = Identity() self.copy_pos_query = Identity() @@ -1436,12 +1448,64 @@ def compute_angular_freq_basis_triangular(freqs: Tensor, return torch.stack((re, im), dim=-1).to(dtype) +class AngularFreqBasis(nn.Module): + """ + Computes and caches the angular-frequency basis used in relative position scoring. + + num_freqs: the number of frequencies of the sin and cos functions + low_freq_factor: this is approximately the amount by which the lowest frequency will be + less than the highest frequency, the highest frequency being the Nyquist (pi). + The frequencies are close to a geometric series at higher frequency but linear + at low frequency. + """ + def __init__(self, num_freqs: int, low_freq_factor: float = 0.001): + super().__init__() + log_freqs = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs) + freqs = math.pi * (log_freqs.exp() - low_freq_factor) # range from 0 to pi. + freqs[0] = 0.0 # in case of roundoff + self.register_buffer('freqs', freqs, persistent=False) + + self._cached_basis: Optional[Tensor] = None + self._cached_seq_len: int = -1 + self._cached_left_context_len: int = -1 + + def forward(self, seq_len: int, left_context_len: int, device: torch.device) -> Tensor: + """ + Returns basis of shape (2 * seq_len + left_context_len - 1, 2 * num_freqs). + + The result is cached; if the requested (seq_len, left_context_len) fits + within the cached range, the cached tensor is sliced rather than + recomputed. + """ + S = self._cached_seq_len + L = self._cached_left_context_len + if (self._cached_basis is not None + and seq_len <= S + and seq_len + left_context_len <= S + L): + start = S + L - seq_len - left_context_len + end = start + 2 * seq_len + left_context_len - 1 + return self._cached_basis[start:end] + + t = torch.arange(-(seq_len + left_context_len - 1), seq_len, device=device) + basis = compute_angular_freq_basis_triangular(self.freqs, t, scale=False) + # basis: (2 * seq_len + left_context_len - 1, num_freqs, 2) + basis = basis.permute(0, 2, 1) + # permute it because of how we did the low-pass initialization of weight, we want + # the cos and sin parts to each be continuous ranges, not interleaved. + basis = basis.reshape(basis.shape[0], -1) + # basis: (2 * seq_len + left_context_len - 1, 2 * num_freqs) + + self._cached_basis = basis + self._cached_seq_len = seq_len + self._cached_left_context_len = left_context_len + return basis + + class RelPosScores(nn.Module): def __init__(self, num_heads: int, pos_head_dim: int, - num_freqs: int, - low_freq_factor: float = 0.001): + num_freqs: int): """ Implementation of relative position scores; where conventional relative position scores would use sinusoids, we treat each sinusoid frequency as the central frequency of a @@ -1458,10 +1522,6 @@ def __init__(self, be identical to the query-dim but we make the "position query" independent of the main query and with a smaller dimension. num_freqs: the number of frequencies of the sin and cos functions - low_freq_factor: this is approximately the amount by which the lowest frequency will be - less than the highest frequency, the highest frequency being the Nyquist (pi). - The frequencies are close to a geometric series at higher frequency but linear - at low frequency. """ super().__init__() self.weight = nn.Parameter(0.04 * torch.randn(num_heads, pos_head_dim, 2 * num_freqs)) @@ -1471,10 +1531,8 @@ def __init__(self, for _ in range(10): self.weight[:] = (2 ** -0.5) * (self.weight + self.weight.roll(1, dims=2)) - log_freqs = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs) - freqs = math.pi * (log_freqs.exp() - low_freq_factor) # these range from 0 to pi. - freqs[0] = 0.0 # in case of roundoff (it should be 0, mathematically) - self.register_buffer('freqs', freqs, persistent=False) + # angular_freq_basis will be set externally as a shared module + self.angular_freq_basis: Optional[AngularFreqBasis] = None def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: """ @@ -1491,14 +1549,8 @@ def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: """ (batch_size, num_heads, seq_len, pos_head_dim) = p.shape - freqs = self.freqs # base freqs - t = torch.arange(-(seq_len + left_context_len - 1), seq_len, device=p.device) - basis = compute_angular_freq_basis_triangular(freqs, t, scale=False) - # basis: (2 * seq_len + left_context_len - 1, num_freqs, 2) - basis = basis.permute(0, 2, 1) - # permute it because of how we did the low-pass initialization of weight, we want - # the cos and sin parts to each be continuous ranges, not interleaved. - basis = basis.reshape(basis.shape[0], -1) # (2 * seq_len + left_context_len - 1, 2 * num_freqs) + basis = self.angular_freq_basis(seq_len, left_context_len, p.device) + # basis: (2 * seq_len + left_context_len - 1, 2 * num_freqs) x = torch.matmul(self.weight, basis.t()) assert x.shape == (num_heads, pos_head_dim, 2 * seq_len + left_context_len - 1) From 565ece06ce948c0b0a8f3553336df91d2576c6e2 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 13 Apr 2026 11:44:37 +0800 Subject: [PATCH 1060/1191] minor update --- egs/librispeech/ASR/zapformer/zapformer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index e305d91372..5e830c88bf 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -176,8 +176,8 @@ def _to_tuple(x): # Share a single AngularFreqBasis instance across all layers within each encoder stack for encoder in self.encoders: - shared_basis = AngularFreqBasis(num_freqs=num_freqs) - for layer in encoder.layers: + shared_basis = encoder.layers[0].self_attn.rel_pos.angular_freq_basis + for layer in encoder.layers[1:]: layer.self_attn.rel_pos.angular_freq_basis = shared_basis self.out_norm = RmsNorm() @@ -1531,8 +1531,7 @@ def __init__(self, for _ in range(10): self.weight[:] = (2 ** -0.5) * (self.weight + self.weight.roll(1, dims=2)) - # angular_freq_basis will be set externally as a shared module - self.angular_freq_basis: Optional[AngularFreqBasis] = None + self.angular_freq_basis = AngularFreqBasis(num_freqs) def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor: """ From 486a368f6e8034e9a1d05348179aa186c7d58898 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 14 Apr 2026 14:41:36 +0800 Subject: [PATCH 1061/1191] Restore correlation_limiter but with power reduced from .4 to .35 --- egs/librispeech/ASR/zapformer/zapformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 4566a95fd6..4fb3871b48 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -546,8 +546,8 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) - #power = 0.4 # power should be between 0 and 1. 1 would mean cov == I (unattainable) - #self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) + power = 0.35 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn = MultiheadRelPosGatedSelfAttention( embed_dim, @@ -595,9 +595,9 @@ def forward( """ src_orig = src - #src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), - # 2. * aux_loss_scale, mask=src_key_padding_mask), - # None) + src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), + 2. * aux_loss_scale, mask=src_key_padding_mask), + None) src_pre_ff1 = src From 21da053d46db1f7003b80c25b3fa2bb79c9cbb8c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 14 Apr 2026 15:15:17 +0800 Subject: [PATCH 1062/1191] Introduce scale of query_head_dim**-0.5 to keys --- egs/librispeech/ASR/zapformer/zapformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 2469834429..d0491b09c0 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -981,7 +981,7 @@ def __init__( # it would be necessary to apply the scaling factor in the forward function. self.qkp_in_proj = ScaledLinear( embed_dim, in_proj_dim, - bias=True, initial_scale=0.125 * query_head_dim**-0.25 + bias=True, initial_scale=0.125, ) self.rel_pos = RelPosScores(num_heads, pos_head_dim, num_freqs=num_freqs) @@ -1038,7 +1038,7 @@ def forward( # self-attention q = x_qkp[..., 0:query_dim] - k = x_qkp[..., query_dim : 2 * query_dim] + k = x_qkp[..., query_dim : 2 * query_dim] * (query_head_dim ** -0.5) p = x_qkp[..., 2 * query_dim:] q = self.copy_query(q) # for diagnostics only, does nothing. @@ -1450,7 +1450,7 @@ def compute_angular_freq_basis_triangular(freqs: Tensor, class AngularFreqBasis(nn.Module): """ - Computes and caches the angular-frequency basis used in relative position scoring. + Computes and caches the angular-frequency basis used in relative position scoring. num_freqs: the number of frequencies of the sin and cos functions low_freq_factor: this is approximately the amount by which the lowest frequency will be @@ -1484,7 +1484,7 @@ def forward(self, seq_len: int, left_context_len: int, device: torch.device) -> and seq_len + left_context_len <= S + L): start = S + L - seq_len - left_context_len end = start + 2 * seq_len + left_context_len - 1 - return self._cached_basis[start:end] + return self._cached_basis[start:end] t = torch.arange(-(seq_len + left_context_len - 1), seq_len, device=device) basis = compute_angular_freq_basis_triangular(self.freqs, t, scale=False) From f26a443d21a31993f31a4d34b509268c0b32feea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Apr 2026 16:40:41 +0800 Subject: [PATCH 1063/1191] Add code to compute and print projection overlap --- egs/librispeech/ASR/zapformer/train.py | 2 ++ egs/librispeech/ASR/zapformer/zapformer.py | 35 ++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ab4c9dcd75..d405e79d96 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,6 +1358,8 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) + model.encoder.compute_projection_overlap() + optimizer = Rubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index d0491b09c0..b0e99bf8be 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -288,6 +288,38 @@ def forward( return x, x_lens + + def compute_projection_overlap(self): + # between pairs of encoders + N = len(self.encoders) + for i in range(N): + for j in range(i + 1): + # multipying by lr_scale keeps the scale correct so they will be orthogonal + proj_i = self.encoders[i].proj.weight * self.encoders[i].proj.lr_scale + proj_j = self.encoders[j].proj.weight * self.encoders[j].proj.lr_scale + if proj_i.shape[1] > proj_j.shape[1]: + proj_i, proj_j = proj_j, proj_i # swap them + in_dim_i = proj_i.shape[1] # now this is <= proj_j.shape[1] + in_dim_j = proj_j.shape[1] + assert in_dim_i <= in_dim_j + assert in_dim_j % in_dim_i == 0 # in_dims must be multiples of each other + R = in_dim_j // in_dim_i # e.g. 1, 2, 4 + assert R in [1, 2, 4, 8] + new_proj_i = torch.zeros(R, proj_i.shape[0], R, proj_i.shape[1], device=proj_i.device) + for r in range(R): + new_proj_i[r, :, r, :] = proj_i + new_proj_i = new_proj_i.reshape(R * proj_i.shape[0], R * proj_i.shape[1]) + assert new_proj_i.shape[1] == proj_j.shape[1] + proj_i = new_proj_i + cov_i = torch.matmul(proj_i.t(), proj_i) + cov_j = torch.matmul(proj_j.t(), proj_j) + # denominator is the minimum of the two rather than their geometric mean, + # because due to the orthogonal constraint, the maximum possible value of (cov_i * cov_j).sum() would be the + # smaller of the two dimension. + cosine = (cov_i * cov_j).sum() / min((cov_i * cov_i).sum(), (cov_j * cov_j).sum()) + logging.info(f"overlap[{i}, {j}] = {cosine}") + + def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int ) -> Optional[Tensor]: @@ -839,6 +871,7 @@ def forward( return src + def streaming_forward( self, src: Tensor, @@ -2152,6 +2185,8 @@ def _test_zapformer_streaming(): left_context_frames=(left_context_frames,), ) + model.compute_projection_overlap() + model.eval() x_full = torch.randn(seq_len, batch_size, input_dim) From 39a56ffaa3404fc3069420c046def4167ea57f8f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 15 Apr 2026 18:24:51 +0800 Subject: [PATCH 1064/1191] Introduce a new loss term that makes the projection-overlap be at least 0.7. --- egs/librispeech/ASR/zapformer/train.py | 2 -- egs/librispeech/ASR/zapformer/zapformer.py | 15 ++++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d405e79d96..ab4c9dcd75 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1358,8 +1358,6 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - model.encoder.compute_projection_overlap() - optimizer = Rubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index b0e99bf8be..1cd13efcad 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -286,14 +286,20 @@ def forward( x = self.out_norm(x) + if self.training: + x = with_loss(x, aux_loss_scale * self.compute_projection_overlap()) + return x, x_lens def compute_projection_overlap(self): + min_overlap = 0.7 # we can tune this + + tot_loss = 0.0 # between pairs of encoders N = len(self.encoders) for i in range(N): - for j in range(i + 1): + for j in range(i): # multipying by lr_scale keeps the scale correct so they will be orthogonal proj_i = self.encoders[i].proj.weight * self.encoders[i].proj.lr_scale proj_j = self.encoders[j].proj.weight * self.encoders[j].proj.lr_scale @@ -317,7 +323,10 @@ def compute_projection_overlap(self): # because due to the orthogonal constraint, the maximum possible value of (cov_i * cov_j).sum() would be the # smaller of the two dimension. cosine = (cov_i * cov_j).sum() / min((cov_i * cov_i).sum(), (cov_j * cov_j).sum()) - logging.info(f"overlap[{i}, {j}] = {cosine}") + + loss = (min_overlap - cosine).relu() + tot_loss = tot_loss + loss + return tot_loss def _get_attn_mask( @@ -2185,7 +2194,7 @@ def _test_zapformer_streaming(): left_context_frames=(left_context_frames,), ) - model.compute_projection_overlap() + #model.compute_projection_overlap() model.eval() From fdd72e6a60b2f174fa531fc0a4449769c0148a40 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Apr 2026 14:11:12 +0800 Subject: [PATCH 1065/1191] take some changes to the metric from deterministic_invertible3021conv_overlap2 (measure time-averaging only0 and make the overlap metric be at least 0.66. --- egs/librispeech/ASR/zapformer/zapformer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 1cd13efcad..3fd0f9e18b 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -292,8 +292,8 @@ def forward( return x, x_lens - def compute_projection_overlap(self): - min_overlap = 0.7 # we can tune this + def compute_projection_overlap(self, verbose: bool = False): + min_overlap = 0.66 # we can tune this tot_loss = 0.0 # between pairs of encoders @@ -311,21 +311,21 @@ def compute_projection_overlap(self): assert in_dim_j % in_dim_i == 0 # in_dims must be multiples of each other R = in_dim_j // in_dim_i # e.g. 1, 2, 4 assert R in [1, 2, 4, 8] - new_proj_i = torch.zeros(R, proj_i.shape[0], R, proj_i.shape[1], device=proj_i.device) - for r in range(R): - new_proj_i[r, :, r, :] = proj_i - new_proj_i = new_proj_i.reshape(R * proj_i.shape[0], R * proj_i.shape[1]) - assert new_proj_i.shape[1] == proj_j.shape[1] - proj_i = new_proj_i + + proj_i = proj_i.repeat(1, R).reshape(proj_i.shape[0], proj_j.shape[1]) * (R ** -0.5) + # proj_i should still have orthogonal rows. + # now proj_j and proj_i have same dimension one (in_dim) cov_i = torch.matmul(proj_i.t(), proj_i) cov_j = torch.matmul(proj_j.t(), proj_j) # denominator is the minimum of the two rather than their geometric mean, # because due to the orthogonal constraint, the maximum possible value of (cov_i * cov_j).sum() would be the # smaller of the two dimension. - cosine = (cov_i * cov_j).sum() / min((cov_i * cov_i).sum(), (cov_j * cov_j).sum()) + cosine = (cov_i * cov_j).sum() / proj_i.shape[0] loss = (min_overlap - cosine).relu() tot_loss = tot_loss + loss + if verbose: + logging.info(f"overlap[{i}, {j}] = {cosine}") return tot_loss @@ -2194,7 +2194,7 @@ def _test_zapformer_streaming(): left_context_frames=(left_context_frames,), ) - #model.compute_projection_overlap() + model.compute_projection_overlap(verbose=True) model.eval() From 219a672a8ec1e931e586ffa841ffdc2575cd9310 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Apr 2026 14:44:31 +0800 Subject: [PATCH 1066/1191] Remove correlation limiter loss --- egs/librispeech/ASR/zapformer/zapformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 3fd0f9e18b..627aa13bb1 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -597,8 +597,8 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) - power = 0.35 # power should be between 0 and 1. 1 would mean cov == I (unattainable) - self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) + #power = 0.35 # power should be between 0 and 1. 1 would mean cov == I (unattainable) + #self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) self.self_attn = MultiheadRelPosGatedSelfAttention( embed_dim, @@ -647,9 +647,9 @@ def forward( """ src_orig = src - src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), - 2. * aux_loss_scale, mask=src_key_padding_mask), - None) + #src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), + # 2. * aux_loss_scale, mask=src_key_padding_mask), + #None) src_pre_ff1 = src From 2e847d57d2cfaa5a01eb665c4a92c8d97e459cea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Apr 2026 15:52:58 +0800 Subject: [PATCH 1067/1191] Fix bug in rubik.py about unnecessarily unsqueezing --- egs/librispeech/ASR/zapformer/rubik.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 9b2ad4bde4..238531eb57 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -318,16 +318,13 @@ def step(self, closure=None): state["step"] = 0 cur_step = 0 - def u(x): - return x.unsqueeze(0) - if p.numel() == 1: # "scalar_scale" the assumed parameter scale used for # scalars, in this case it just acts as a multiplier on # the learning rate. p += group["scalar_scale"] * adam_step(group, state, grad) else: - p += scaling_step(group, u(p.detach()), state, u(grad))[0] + p += scaling_step(group, p.detach(), state, grad) state["step"] = cur_step + 1 From 9dc723574cae4a201c674feaaf956e27ce02150c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 18 Apr 2026 16:10:43 +0800 Subject: [PATCH 1068/1191] Use fourth_power_rms in normalizing step size in batched_rubik --- .../ASR/zapformer/batched_rubik.py | 21 ++++++++++++++++++- egs/librispeech/ASR/zapformer/rubik.py | 19 ++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index b511371bf3..1ec763ef1a 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -260,7 +260,7 @@ def min_sum_scale(x, y): # accumulator with beta equal to beta1. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - d_norm3 = d_norm2 * (assumed_scale / ((d_norm2 ** 2).mean(dim=(1, 2), keepdim=True) + eps).sqrt()) + d_norm3 = d_norm2 * (assumed_scale / (fourth_power_rms(d_norm2) + eps)) moving_update = d_norm3 @@ -318,6 +318,25 @@ def scaling_step(group, param, state, grad): return param * delta_scale + scale * delta +def fourth_power_rms(x): + # compute the RMS values of x in a way that uses fourth rather than second powers of + # singular values. Test: + # fourth_power_rms(torch.randn(2, 1000, 3)) + # tensor([[[1.0045]], + # [[1.0148]]]) + #>>> fourth_power_rms(torch.randn(2, 3, 1000)) + #tensor([[[0.9880]], + # [[0.9984]]]) + (_batch, rows, cols) = x.shape + if rows < cols: + y = torch.matmul(x, x.transpose(1, 2)) + return ((y ** 2).sum(dim=(1, 2), keepdim=True) / (rows * cols * cols)) ** 0.25 + else: + y = torch.matmul(x.transpose(1, 2), x) + return ((y ** 2).sum(dim=(1, 2), keepdim=True) / (cols * rows * rows)) ** 0.25 + + + def adam_step(group, state, grad): lr = group["lr"] step = state["step"] diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 238531eb57..7c10987c2a 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -68,6 +68,21 @@ def prod(l): return prod(shape[:i]), prod(shape[i:]) assert False, shape +def fourth_power_rms(x): + # compute the RMS values of x in a way that uses fourth rather than second powers of + # singular values. Test: + # fourth_power_rms(torch.randn(3, 1000)) + # tensor(0.9783) + # fourth_power_rms(torch.randn(1000, 3)) + # tensor(0.9880) + (rows, cols) = x.shape + if rows < cols: + y = torch.matmul(x, x.t()) + return ((y ** 2).sum() / (rows * cols * cols)) ** 0.25 + else: + y = torch.matmul(x.t(), x) + return ((y ** 2).sum() / (cols * rows * rows)) ** 0.25 + def cubic_decay_step(group, state, grad): delta = grad @@ -137,8 +152,6 @@ def min_sum_scale(x, y): d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. - d_norm1_sq = d_norm1 ** 2 - # first update row_stats. row_stats.mul_(beta2).add_((d_norm1 ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) @@ -155,7 +168,7 @@ def min_sum_scale(x, y): # accumulator with beta equal to beta1. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - d_norm3 = d_norm2 * (assumed_scale / ((d_norm2 ** 2).mean() + eps) .sqrt()) + d_norm3 = d_norm2 * (assumed_scale / (fourth_power_rms(d_norm2) + eps)) moving_update = d_norm3 From b0a1b50ceda6ba99ac21c9b50bcddae0ab6a68c6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 19 Apr 2026 13:08:55 +0800 Subject: [PATCH 1069/1191] Add model.encoder.compute_projection_overlap(verbose=True) every epoch --- egs/librispeech/ASR/zapformer/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ab4c9dcd75..635abe6a6b 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1562,6 +1562,7 @@ def remove_short_and_long_utt(c: Cut): # train_dl.sampler.set_epoch(epoch) # because we just created the sampler and its seed already depends on the epoch. + model.encoder.compute_projection_overlap(verbose=True) # for diagnostics seed = params.seed + 50 * epoch + 512 * rank From 74ca2956b1358b06133c597707e47357891dde8b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 19 Apr 2026 13:18:04 +0800 Subject: [PATCH 1070/1191] Fix bug with compute_projection_overlap loss not being scaled. --- egs/librispeech/ASR/zapformer/zapformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 627aa13bb1..0457210ee2 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -287,7 +287,9 @@ def forward( x = self.out_norm(x) if self.training: - x = with_loss(x, aux_loss_scale * self.compute_projection_overlap()) + # all of our losses and aux losses are proportional to the number of frames of data, so + # we multiply by that factor. + x = with_loss(x, aux_loss_scale * x.shape[0] * x.shape[1] * self.compute_projection_overlap()) return x, x_lens From 0887c1d64f96293333c5c98af49b66bb4094a385 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 19 Apr 2026 13:34:08 +0800 Subject: [PATCH 1071/1191] Fix printing of projection overlap per epoch, w.r.t. DDP --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 635abe6a6b..28b147841b 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1562,7 +1562,7 @@ def remove_short_and_long_utt(c: Cut): # train_dl.sampler.set_epoch(epoch) # because we just created the sampler and its seed already depends on the epoch. - model.encoder.compute_projection_overlap(verbose=True) # for diagnostics + (model.module if isinstance(model, DDP) else model).encoder.compute_projection_overlap(verbose=True) # for diagnostics seed = params.seed + 50 * epoch + 512 * rank From 71fdfaacc72b5b97da4604ef705ab383448a568e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 20 Apr 2026 19:09:59 +0800 Subject: [PATCH 1072/1191] Change factor in setting alpha from .5 to .25, bigger safety factor to prevent divergence of cubic update --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- egs/librispeech/ASR/zapformer/rubik.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 1ec763ef1a..f878a6d876 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -233,7 +233,7 @@ def min_sum_scale(x, y): prod3 = compute_scaled_prod3(d_norm1) - alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + alpha = (0.25 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) # we multiply prod3 by row_col_scale to "un-normalize". # In the normal case where we're not limited by stability-of-update-concerns, # the next line of code is equivalent to: diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 7c10987c2a..5fcb993608 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -143,7 +143,7 @@ def min_sum_scale(x, y): prod3 = compute_scaled_prod3(d_norm1) - alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + alpha = (0.25 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) # we multiply prod3 by row_col_scale to "un-normalize". # In the normal case where we're not limited by stability-of-update-concerns, # the next line of code is equivalent to: From 2d5697dc275ac1e8fd8ef4c9073c7525f049bd08 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 20 Apr 2026 20:26:43 +0800 Subject: [PATCH 1073/1191] Do the remaining part of the shrinkage as linear shrinkage. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index f878a6d876..e1c95fcdc3 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -218,7 +218,7 @@ def min_sum_scale(x, y): delta = delta.reshape(*d.shape) d.add_(delta) # the scale used here doesn't matter as it all gets normalized. - d.mul_(1 - (linear_decay_proportion * (1 - beta1))) + #d.mul_(1 - (linear_decay_proportion * (1 - beta1))) d2 = d ** 2 @@ -234,12 +234,17 @@ def min_sum_scale(x, y): prod3 = compute_scaled_prod3(d_norm1) alpha = (0.25 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + + alpha_remaining = -(1-beta1) - alpha # will be negative. + # we multiply prod3 by row_col_scale to "un-normalize". # In the normal case where we're not limited by stability-of-update-concerns, # the next line of code is equivalent to: # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_proportion) d.add_((prod3 * row_col_scale) * alpha) + d.mul_(1. - alpha_remaining) + d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. d_norm1_sq = d_norm1 ** 2 From 65828997f53b3caf556276a00248d53b75c7cbe1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 20 Apr 2026 20:55:04 +0800 Subject: [PATCH 1074/1191] Revert safety factor on alpha from .25 to .5 --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index e1c95fcdc3..293f297399 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -233,7 +233,7 @@ def min_sum_scale(x, y): prod3 = compute_scaled_prod3(d_norm1) - alpha = (0.25 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) alpha_remaining = -(1-beta1) - alpha # will be negative. From 6e391a649a52b2e562bb2dda15ce5bd8c8a9c842 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 20 Apr 2026 21:46:38 +0800 Subject: [PATCH 1075/1191] Increase safety factor from 0.5 to 0.66 --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 293f297399..a3ab4885bd 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -233,7 +233,7 @@ def min_sum_scale(x, y): prod3 = compute_scaled_prod3(d_norm1) - alpha = (0.5 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + alpha = (0.66 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) alpha_remaining = -(1-beta1) - alpha # will be negative. From 956a015dca894c1db046bd90be44c565d65dac32 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 20 Apr 2026 22:37:58 +0800 Subject: [PATCH 1076/1191] Incorporate refactoring of rubik, with safety_factor=0.66 and use the remaining shrinkage as linear shrinkage --- .../ASR/zapformer/batched_rubik.py | 337 ++++++++++-------- egs/librispeech/ASR/zapformer/rubik.py | 312 +++++++++------- 2 files changed, 368 insertions(+), 281 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index a3ab4885bd..aab59bbadc 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -133,222 +133,258 @@ def compute_prod3(x): x2 = torch.matmul(x.transpose(-2, -1), x) return torch.matmul(x, x2) -def compute_scaled_prod3(x): - # computes 3-way matrix power x^3 (x is treated as a batch of matrices) with a scaling such that (for each - # matrix in the batch) if all the singular values of the matrix are the same, the result will be identical to the input. +def three_way_product(x): + """ returns the 3-way matrix product x @ x.t() @ x """ + assert x.ndim >= 2 + if x.shape[0] <= x.shape[1]: + x2 = torch.matmul(x, x.transpose(-2, -1)) + return torch.matmul(x2, x) + else: + x2 = torch.matmul(x.transpose(-2, -1), x) + return torch.matmul(x, x2) - rows, cols = x.shape[-2], x.shape[-1] +def scaled_three_way_product(x): + """ + Returns alpha * (x @ x.t() @ x), + where alpha is computed from the 2-norm of x in such a way that if all the singular values of + x are the same, it will return x itself. (There is only one such formula.) If the singular + values of x differ from each other, the result will in general have a larger norm than x. + """ + rows, cols = x.shape[-2], x.shape[-1] eps = 1.0e-40 x_meansq = (x ** 2).mean(dim=(-2, -1), keepdim=True) + eps x = x * (x_meansq * max(rows, cols)) ** (-1/3) - return compute_prod3(x) + return three_way_product(x) + +def clip_alpha(x: Tensor, y: Tensor, alpha: float) -> Tensor: + """ + In a situation where you plan to do: + x.add_(y, alpha=alpha) + returns a possibly-modified value of alpha that + but modified to prevent divergence on x (may use an alpha closer zero if necessary) + """ + # min_sum_scale the scale beta such that (x + beta y) is minimized; x and + # y each have 2 dimensions. min_sum_scale is expected to be negative. + min_sum_scale = -(x * y).sum(dim=(1, 2), keepdim=True) / ((y ** 2).sum(dim=(1, 2), keepdim=True) + 1.0e-40) + # the safety factor of 0.66 means, don't go all the way to where the dot product of the + # change to x with x would be zero, only go some way to there. + safety_factor = 0.66 + alpha = (safety_factor * min_sum_scale).clamp(min=alpha) + return alpha -def get_matrix_shape(shape): +def matrix_shape(shape): + """ + shape is expected to be a torch.Size or a list with at least two dimensions. + Returns (rows, cols) such that a tensor of shape `shape` can be reshaped + to size (rows, cols), by combining dimensions in a way that minimizes the + difference between rows and cols. e.g. matrix_shape([ 2, 3, 10 ]) = (6, 10) + """ shape = list(shape) - batch_size = shape[0] # batch size is 1st element of shape - shape = shape[1:] - def prod(l): - ans = l[0] - for n in l[1:]: - ans = ans * n - return ans - n = len(shape) - diffs = [ ] - for i in range(1, n): - prod1 = prod(shape[:i]) - prod2 = prod(shape[i:]) - diff = abs(prod1 - prod2) - diffs.append(diff) + cumprod = [ ] + numel = 1 + for k in shape: + cumprod.append(k) + numel = numel * k + diffs = [ abs(k - numel // k) for k in cumprod ] min_diff = min(diffs) - for i in range(1, n): - if diffs[i-1] == min_diff: - return batch_size, prod(shape[:i]), prod(shape[i:]) + for i in range(len(shape)): + if diffs[i] == min_diff: + return cumprod[i], numel // cumprod[i] + assert False, shape -def cubic_decay_step(group, state, grad): - delta = grad +def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): + """ + Normalize the rms of x using row-wise and column-wise stats, while + updating the moving-average stats; return the normalized x. + Shapes: + x: (batch_size, rows, cols) +row_stats: (batch_size, rows, 1) +col_stats: (batch_size, 1, cols) + Returns: + normalized x, shape: (batch_size, rows, cols) + """ + row_stats.mul_(beta2).add_((x ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) + x = x / (row_stats.sqrt() + eps) + col_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + return x / (col_stats.sqrt() + eps) + + +def no_momentum_step(group, state, grad): + # computes an update direction using magnitude normalization but no momentum + # (no beta1, in adam terminology). grad is assumed to have exactly three + # dimensions (grad.ndim == 3), representing (batch_size, rows, cols). + # the grad is normalized using adafactor-like + # row and column statistics, but done sequentially over first rows and then + # columns + step = state["step"] + lr = group["lr"] + eps = group["eps"] + + # the following modification to beta2 warms up beta2 gradually. + # For the first step we just take the current stats; this is similar to + # a sign-only update. + beta2 = min(group["beta2"], 1. - 1. / (1. + 0.2 * step)) + + (batch_size, rows, cols) = grad.shape + try: + row_stats = state["direct_row_stats"] + col_stats = state["direct_col_stats"] + except KeyError: + row_stats = torch.zeros(batch_size, rows, 1, device=grad.device, dtype=grad.dtype) + col_stats = torch.zeros(batch_size, 1, cols, device=grad.device, dtype=grad.dtype) + state["direct_row_stats"] = row_stats + state["direct_col_stats"] = col_stats + + return -lr * normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + + +def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] beta_ceil = 1. - 1. / (10. + 0.2 * step) beta1 = min(group["beta1"], beta_ceil) beta2 = min(group["beta2"], beta_ceil) - direct = group["direct"] + direct = group["direct"] # scale on non-momentum step cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion - try: - stored_delta = state["delta"] - except KeyError as e: + orig_shape = grad.shape + batch_size = orig_shape[0] + rows, cols = matrix_shape(orig_shape[1:]) + grad = grad.reshape(batch_size, rows, cols) + + if "moving_grad" not in state: assert step < 2 - # scalar. use conventional momentum. - stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) - state["delta"] = stored_delta - - def min_sum_scale(x, y): - # returns the scale alpha such that (x + alpha y) is minimized. x and y have - # the same shape and the shape of alpha is (x.shape[0], 1, 1, ...). - assert x.ndim > 1 - dims = list(range(1, x.ndim)) - yy = (y ** 2).sum(dim=dims, keepdim=True) - xy = (y * x).sum(dim=dims, keepdim=True) - # sum square of x + alpha y is xx + alpha^2 yy + 2 alpha xy - # d/dalpha[that] = 2 alpha yy + 2 xy - # alpha = xy / yy - return -xy / (yy + eps) - - d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) - assert d.untyped_storage() is stored_delta.untyped_storage() - (batch_size, rows, cols) = d.shape - - if "row_stats" not in state: - state["row_stats"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) - state["direct_row_stats"] = torch.ones(d.shape[0], d.shape[1], 1, device=d.device, dtype=d.dtype) - state["col_stats"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) - state["direct_col_stats"] = torch.ones(d.shape[0], 1, d.shape[2], device=d.device, dtype=d.dtype) + state["moving_grad"] = torch.zeros(batch_size, rows, cols, device=grad.device) + state["row_stats"] = torch.ones(batch_size, rows, 1, device=grad.device) + state["col_stats"] = torch.ones(batch_size,1, cols, device=grad.device) + moving_grad = state["moving_grad"] row_stats = state["row_stats"] col_stats = state["col_stats"] - direct_row_stats = state["direct_row_stats"] - direct_col_stats = state["direct_col_stats"] - - delta = delta.reshape(*d.shape) - d.add_(delta) # the scale used here doesn't matter as it all gets normalized. - #d.mul_(1 - (linear_decay_proportion * (1 - beta1))) + # add the grad to the moving-average grad; the scaling factor used here + # doesn't matter as it all gets normalized later. + moving_grad.add_(grad) - d2 = d ** 2 + # We'll scale both before and after the cubic decay; this can be viewed as + # doing the cubic decay in a preconditioned space where the preconditioner + # is 1 / row_col_denom. (The row and column stats will be updated later). + # Looking at this code may give the impression that we are mistakenly + # normalizing "twice". Actually we have an "equilibrium argument" why this + # is actually OK and will give correctly-normalized data. + row_denom = (row_stats.sqrt() + eps) + col_denom = (col_stats.sqrt() + eps) + invP = row_denom * col_denom # inverse preconditioner P - # we'll scale both before and after the cubing. - # the lines where we divide by sqrt of the mean are so we don't double - # count the scalar component of this. - row_scale = (row_stats + eps).sqrt() - col_scale = (col_stats + eps).sqrt() - row_col_scale = row_scale * col_scale + moving_grad_precon = moving_grad / invP # preconditioned moving_grad - d_norm1 = d / row_col_scale # this is the first of two steps of normalizing by these stats. + # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were + # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. + prod3 = scaled_three_way_product(moving_grad_precon) - prod3 = compute_scaled_prod3(d_norm1) + cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) + # cubic_alpha shape: (batch_size, 1, 1) - alpha = (0.66 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) + linear_alpha = -(1-beta1) - cubic_alpha # will be negative. - alpha_remaining = -(1-beta1) - alpha # will be negative. + # the next line undoes the preconditioning so we can accumulate gradient + # stats in the "canonical basis" of the gradients, for consistency. + moving_grad_cubic_decay = moving_grad_precon * invP + moving_grad_linear_decay = moving_grad * beta1 - # we multiply prod3 by row_col_scale to "un-normalize". - # In the normal case where we're not limited by stability-of-update-concerns, - # the next line of code is equivalent to: - # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_proportion) - d.add_((prod3 * row_col_scale) * alpha) + moving_grad_precon.add_(prod3 * cubic_alpha) + moving_grad_precon.mul_(1. - linear_alpha) - d.mul_(1. - alpha_remaining) + # update moving_grad as interpolation between linear decay and cubic decay. + moving_grad[:] = moving_grad_precon * invP - d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. + # Now compute "negative_update" which is negative_update_precon multiplied again by the + # preconditioner, this takes us from the preconditioned to the canonical co-ordinates but now treating the quantity as a parameter-update + # rather than as a gradient. it is going to be very close to: + # negative_update = moving_grad_precon / invP + # but we also update the preconditioner. Note: practically speaking we are multiplying + # by the same thing twice though. + negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) - d_norm1_sq = d_norm1 ** 2 - - # first update row_stats. - row_stats.mul_(beta2).add_((d_norm1 ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) - - # d_norm1b means we've doing the second normalization but only by rows so far. - d_norm1b = d_norm1 / (row_stats + eps).sqrt() - - col_stats.mul_(beta2).add_((d_norm1b ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) - - d_norm2 = d_norm1b / (col_stats + eps).sqrt() - - # do "immediate" normalization of total norm to make the overall scale of the update what + # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style # accumulator with beta equal to beta1. + # This should make divergence less likely. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - d_norm3 = d_norm2 * (assumed_scale / (fourth_power_rms(d_norm2) + eps)) - - moving_update = d_norm3 + negative_update = negative_update * (assumed_scale / ((negative_update ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) if direct == 0.0: - return -lr * moving_update.reshape(*grad.shape) - - # row/col normalization of direct/bypass gradient "delta". - direct_row_stats.mul_(beta2).add_((delta ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) - delta = delta / (direct_row_stats + eps).sqrt() - direct_col_stats.mul_(beta2).add_((delta ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) - delta = delta / (direct_col_stats + eps).sqrt() + ans = -lr * negative_update + else: + ans = ((1. - direct) * -lr) * negative_update + direct * no_momentum_step(group, state, grad) - ans = (-lr * (1-direct)) * moving_update + (-lr * direct) * delta - return ans.reshape(*grad.shape) + return ans.reshape(orig_shape) def scaling_step(group, param, state, grad): + # we reach here for biases and weights but not scalars. + # This does three things things: + # (i) multiply the step from "cubic_decay" by an estimate of the parameter scale + # (ii) apply parameter decay + # (iii) update the parameter scale, which means shrinking or growing the whole tensor lr = group["lr"] - - momentum = 0.95 - is_weight = grad.ndim >= 3 + momentum = group["scale_momentum"] # e.g. 0.95 + is_weight = grad.ndim >= 2 min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] - # the "scale" is implicitly a scalar, even though it is learned in log space; apply scalar_scale to its + # the scaling factor is implicitly a scalar; apply scalar_scale to its # learning rate. scalar_scale = group["scalar_scale"] - if grad.ndim >= 3 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): + if grad.ndim >= 2 and grad.numel() != max(grad.shape): delta = cubic_decay_step(group, state, grad) else: # biases and similar-shaped tensors delta = adam_step(group, state, grad) + dims = list(range(1, param.ndim)) + try: scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] - except: - scale = (param ** 2).mean(dim=list(range(1, param.ndim)), keepdim=True).sqrt().clamp(min=min_scale, max=max_scale).to(torch.float) + except KeyError: + scale = (param ** 2).mean(dim=dims, keepdim=True).sqrt().clamp( + min=min_scale, max=max_scale).to(torch.float) scale_grad_buf = torch.zeros_like(scale) state["scale"] = scale state["scale_grad_buffer"] = scale_grad_buf - dims = list(range(1, param.ndim)) - scale_grad = (grad * param.detach()).sum(dim=dims, keepdim=True) - scale_grad_buf.mul_(momentum).add_(scale_grad) + scale_grad_buf.mul_(momentum).add_(scale_grad) # simple momentum old_scale = scale.clone() - scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr * scalar_scale) + scale.mul_(1. - lr * scalar_scale * scale_grad_buf.sign()) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale - delta_scale = (scale_ratio * (1 - lr ** 2)) - 1 + delta_scale = (scale_ratio * (1 - (lr ** 2))) - 1 return param * delta_scale + scale * delta -def fourth_power_rms(x): - # compute the RMS values of x in a way that uses fourth rather than second powers of - # singular values. Test: - # fourth_power_rms(torch.randn(2, 1000, 3)) - # tensor([[[1.0045]], - # [[1.0148]]]) - #>>> fourth_power_rms(torch.randn(2, 3, 1000)) - #tensor([[[0.9880]], - # [[0.9984]]]) - (_batch, rows, cols) = x.shape - if rows < cols: - y = torch.matmul(x, x.transpose(1, 2)) - return ((y ** 2).sum(dim=(1, 2), keepdim=True) / (rows * cols * cols)) ** 0.25 - else: - y = torch.matmul(x.transpose(1, 2), x) - return ((y ** 2).sum(dim=(1, 2), keepdim=True) / (cols * rows * rows)) ** 0.25 - - - def adam_step(group, state, grad): + # this is the adam update but with a slight modification / simplification on + # how "bias correction" (startup on small step counts) is dealt with. lr = group["lr"] step = state["step"] eps = group["eps"] - # just hardcode these. we only use this code for biases and scalars. - beta1 = 0.98 - beta2 = 0.98 + beta1 = group["adam_beta1"] + # the following modification to beta2 makes it unnecessary to do bias correction; + # for small step values, we are just computing the mean over the steps so far + beta2 = min(group["adam_beta2"], step / (step + 1)) try: exp_avg = state["exp_avg"] @@ -362,11 +398,7 @@ def adam_step(group, state, grad): exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1)) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - bias_correction2 = 1 - beta2 ** (step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - denom = (exp_avg_sq + eps).sqrt() + denom = exp_avg_sq.sqrt() + eps return -lr * (exp_avg / denom) @@ -409,10 +441,13 @@ def __init__( direct=0.15, # scale on bypass of momentum (beta1) cubic_decay_proportion=0.8, beta2=0.98, - eps=1.0e-16, + eps=1.0e-08, weight_scale_limits=(0.05, 0.25), bias_scale_limits=(0.05, 0.25), scalar_scale=0.075, + adam_beta1=0.98, + adam_beta2=0.98, + scale_momentum=0.95, ): defaults = dict( @@ -425,6 +460,9 @@ def __init__( weight_scale_limits=weight_scale_limits, bias_scale_limits=bias_scale_limits, scalar_scale=scalar_scale, + adam_beta1=adam_beta1, + adam_beta2=adam_beta2, + scale_momentum=scale_momentum, ) param_groups, parameters_names = self._get_names_of_parameters(params) @@ -581,7 +619,6 @@ def step(self, closure=None): state["step"] = cur_step + 1 - return loss @@ -625,11 +662,11 @@ def _test_batched_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.015 + lr = 0.017 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 - optim = BatchedRubik(m.parameters(), lr=lr, direct=0.0, beta1=0.999) + optim = BatchedRubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999) num_epochs = 180 @@ -695,18 +732,18 @@ def lr_lambda(current_step): -def _test_compute_scaled_prod3(): +def _test_scaled_three_way_product(): x = torch.randn(3, 16, 32) _U, _S, V = torch.linalg.svd(x, full_matrices=False) W = V * torch.randn(3, 1, 1) # so now all the singular values of x will be identical (but arbitrary) - X = compute_scaled_prod3(W) + X = scaled_three_way_product(W) #print("X = ", X[0]) #print("W = ", W[0]) assert torch.allclose(W, X, atol=1.0e-03) # but the result won't be identical to the input if the singular values are not all identical. - assert not torch.allclose(x, compute_scaled_prod3(x), atol=1.0e-03) + assert not torch.allclose(x, scaled_three_way_product(x), atol=1.0e-03) if __name__ == "__main__": torch.set_num_threads(1) @@ -725,5 +762,5 @@ def _test_compute_scaled_prod3(): else: hidden_dim = 200 - _test_compute_scaled_prod3() + _test_scaled_three_way_product() _test_batched_rubik(hidden_dim) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 5fcb993608..5d1f0304b7 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -27,175 +27,216 @@ from torch.optim import Optimizer -def compute_prod3(x): - assert x.ndim >= 2 - if x.shape[-2] <= x.shape[-1]: - x2 = torch.matmul(x, x.transpose(-2, -1)) +def three_way_product(x): + """ returns the 3-way matrix product x @ x.t() @ x """ + assert x.ndim == 2 + if x.shape[0] <= x.shape[1]: + x2 = torch.matmul(x, x.t()) return torch.matmul(x2, x) else: - x2 = torch.matmul(x.transpose(-2, -1), x) + x2 = torch.matmul(x.t(), x) return torch.matmul(x, x2) -def compute_scaled_prod3(x): - # computes 3-way matrix power x^3 (x is treated as a batch of matrices) with a scaling such that (for each - # matrix in the batch) if all the singular values of the matrix are the same, the result will be identical to the input. - - rows, cols = x.shape[-2], x.shape[-1] - +def scaled_three_way_product(x): + """ + Returns alpha * (x @ x.t() @ x), + where alpha is computed from the 2-norm of x in such a way that if all the singular values of + x are the same, it will return x itself. (There is only one such formula.) If the singular + values of x differ from each other, the result will in general have a larger norm than x. + """ + rows, cols = x.shape eps = 1.0e-40 x_meansq = (x ** 2).mean(dim=(-2, -1), keepdim=True) + eps x = x * (x_meansq * max(rows, cols)) ** (-1/3) - return compute_prod3(x) + return three_way_product(x) + +def clip_alpha(x: Tensor, y: Tensor, alpha: float) -> Tensor: + """ + In a situation where you plan to do: + x.add_(y, alpha=alpha) + returns a possibly-modified value of alpha that + but modified to prevent divergence on x (may use an alpha closer zero if necessary) + """ + # min_sum_scale the scale beta such that (x + beta y) is minimized; x and + # y each have 2 dimensions. min_sum_scale is expected to be negative. + min_sum_scale = -(x * y).sum() / ((y ** 2).sum() + 1.0e-40) + # the safety factor of 0.66 means, don't go all the way to where the dot product of the + # change to x with x would be zero, only go some way to there. + safety_factor = 0.66 + alpha = (safety_factor * min_sum_scale).clamp(min=alpha) + return alpha + + + -def get_matrix_shape(shape): +def matrix_shape(shape): + """ + shape is expected to be a torch.Size with at least two dimensions. + Returns (rows, cols) such that a tensor of shape `shape` can be reshaped + to size (rows, cols), by combining dimensions in a way that minimizes the + difference between rows and cols. e.g. matrix_shape([ 2, 3, 10 ]) = (6, 10) + """ shape = list(shape) - def prod(l): - ans = l[0] - for n in l[1:]: - ans = ans * n - return ans - n = len(shape) - diffs = [ ] - for i in range(1, n): - prod1 = prod(shape[:i]) - prod2 = prod(shape[i:]) - diff = abs(prod1 - prod2) - diffs.append(diff) + cumprod = [ ] + numel = 1 + for k in shape: + cumprod.append(k) + numel = numel * k + diffs = [ abs(k - numel // k) for k in cumprod ] min_diff = min(diffs) - for i in range(1, n): - if diffs[i-1] == min_diff: - return prod(shape[:i]), prod(shape[i:]) + for i in range(len(shape)): + if diffs[i] == min_diff: + return cumprod[i], numel // cumprod[i] assert False, shape -def fourth_power_rms(x): - # compute the RMS values of x in a way that uses fourth rather than second powers of - # singular values. Test: - # fourth_power_rms(torch.randn(3, 1000)) - # tensor(0.9783) - # fourth_power_rms(torch.randn(1000, 3)) - # tensor(0.9880) - (rows, cols) = x.shape - if rows < cols: - y = torch.matmul(x, x.t()) - return ((y ** 2).sum() / (rows * cols * cols)) ** 0.25 - else: - y = torch.matmul(x.t(), x) - return ((y ** 2).sum() / (cols * rows * rows)) ** 0.25 +def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): + """ + Normalize the rms of x using row-wise and column-wise stats, while + updating the moving-average stats; return the normalized x. + Shapes: + x: (rows, cols) +row_stats: (rows, 1) +col_stats: (1, cols) + Returns: + normalized x, shape: (rows, cols) + """ + row_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + x = x / (row_stats.sqrt() + eps) + col_stats.mul_(beta2).add_((x ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) + return x / (col_stats.sqrt() + eps) + + +def no_momentum_step(group, state, grad): + # computes an update direction using magnitude normalization but no momentum + # (no beta1, in adam terminology). grad is assumed to have exactly two + # dimensions (grad.ndim == 2). the grad is normalized using adafactor-like + # row and column statistics, but done sequentially over first rows and then + # columns + step = state["step"] + lr = group["lr"] + eps = group["eps"] + + # the following modification to beta2 warms up beta2 gradually. + # For the first step we just take the current stats; this is similar to + # a sign-only update. + beta2 = min(group["beta2"], 1. - 1. / (1. + 0.2 * step)) -def cubic_decay_step(group, state, grad): - delta = grad + (rows, cols) = grad.shape + try: + row_stats = state["direct_row_stats"] + col_stats = state["direct_col_stats"] + except KeyError: + row_stats = torch.zeros(rows, 1, device=grad.device, dtype=grad.dtype) + col_stats = torch.zeros(1, cols, device=grad.device, dtype=grad.dtype) + state["direct_row_stats"] = row_stats + state["direct_col_stats"] = col_stats + + return -lr * normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + + +def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] beta_ceil = 1. - 1. / (10. + 0.2 * step) beta1 = min(group["beta1"], beta_ceil) beta2 = min(group["beta2"], beta_ceil) - direct = group["direct"] + + direct = group["direct"] # scale on non-momentum step cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion - try: - stored_delta = state["delta"] - except KeyError as e: - assert step < 2 - # scalar. use conventional momentum. - stored_delta = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) - state["delta"] = stored_delta - - def min_sum_scale(x, y): - # returns the scale alpha such that (x + alpha y) is minimized; x and - # y each have 2 dimensions. - return -(x * y).sum() / ((y ** 2).sum() + eps) + orig_shape = grad.shape + rows, cols = matrix_shape(orig_shape) + grad = grad.reshape(rows, cols) - d = stored_delta.reshape(get_matrix_shape(stored_delta.shape)) - assert d.untyped_storage() is stored_delta.untyped_storage() - (rows, cols) = d.shape - - if "row_stats" not in state: - state["row_stats"] = torch.ones(rows, 1, device=d.device, dtype=d.dtype) - state["direct_row_stats"] = torch.ones(rows, 1, device=d.device, dtype=d.dtype) - state["col_stats"] = torch.ones(1, cols, device=d.device, dtype=d.dtype) - state["direct_col_stats"] = torch.ones(1, cols, device=d.device, dtype=d.dtype) + if "moving_grad" not in state: + assert step < 2 + state["moving_grad"] = torch.zeros(rows, cols, device=grad.device) + state["row_stats"] = torch.ones(rows, 1, device=grad.device) + state["col_stats"] = torch.ones(1, cols, device=grad.device) + moving_grad = state["moving_grad"] row_stats = state["row_stats"] col_stats = state["col_stats"] - direct_row_stats = state["direct_row_stats"] - direct_col_stats = state["direct_col_stats"] - - delta = delta.reshape(*d.shape) - d.add_(delta) # the scale used here doesn't matter as it all gets normalized. - d.mul_(1 - (linear_decay_proportion * (1 - beta1))) + # add the grad to the moving-average grad; the scaling factor used here + # doesn't matter as it all gets normalized later. + moving_grad.add_(grad) - d2 = d ** 2 + # We'll scale both before and after the cubic decay; this can be viewed as + # doing the cubic decay in a preconditioned space where the preconditioner + # is 1 / row_col_denom. (The row and column stats will be updated later). + # Looking at this code may give the impression that we are mistakenly + # normalizing "twice". Actually we have an "equilibrium argument" why this + # is actually OK and will give correctly-normalized data. + row_denom = (row_stats.sqrt() + eps) + col_denom = (col_stats.sqrt() + eps) + invP = row_denom * col_denom # inverse preconditioner P - # we'll scale both before and after the cubing. - # the lines where we divide by sqrt of the mean are so we don't double - # count the scalar component of this. - row_scale = (row_stats + eps).sqrt() - col_scale = (col_stats + eps).sqrt() - row_col_scale = row_scale * col_scale + moving_grad_precon = moving_grad / invP # preconditioned moving_grad - d_norm1 = d / row_col_scale # this is the first of two steps of normalizing by these stats. + # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were + # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. + prod3 = scaled_three_way_product(moving_grad_precon) - prod3 = compute_scaled_prod3(d_norm1) + # next line similar to: + # moving_grad_precon.add_(prod3, alpha=-(1-beta1)) + # but with a precaution for divergence. - alpha = (0.25 * min_sum_scale(d_norm1, prod3)).clamp(min=-cubic_decay_proportion*(1-beta1)) - # we multiply prod3 by row_col_scale to "un-normalize". - # In the normal case where we're not limited by stability-of-update-concerns, - # the next line of code is equivalent to: - # d.add_(prod3 * row_col_scale, alpha=-cubic_decay_proportion) - d.add_((prod3 * row_col_scale) * alpha) + cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) + # cubic_alpha shape: (batch_size, 1, 1) - d_norm1 = d / row_col_scale # updated version of d_norm1 with x3 term subtracted. + linear_alpha = -(1-beta1) - cubic_alpha # will be negative. - # first update row_stats. - row_stats.mul_(beta2).add_((d_norm1 ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + moving_grad_precon.add_(prod3 * cubic_alpha) + moving_grad_precon.mul_(1. - linear_alpha) - # d_norm1b means we've doing the second normalization but only by rows so far. - d_norm1b = d_norm1 / (row_stats + eps).sqrt() + # update moving_grad as interpolation between linear decay and cubic decay. + moving_grad[:] = moving_grad_precon * invP - col_stats.mul_(beta2).add_((d_norm1b ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) + # Now compute "negative_update" which is negative_update_precon multiplied again by the + # preconditioner, this takes us from the preconditioned to the canonical co-ordinates but now treating the quantity as a parameter-update + # rather than as a gradient. it is going to be very close to: + # negative_update = moving_grad_precon / invP + # but we also update the preconditioner. Note: practically speaking we are multiplying + # by the same thing twice though. + negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) - d_norm2 = d_norm1b / (col_stats + eps).sqrt() - - # do "immediate" normalization of total norm to make the overall scale of the update what + # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style # accumulator with beta equal to beta1. + # This should make divergence less likely. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - d_norm3 = d_norm2 * (assumed_scale / (fourth_power_rms(d_norm2) + eps)) - - moving_update = d_norm3 + negative_update = negative_update * (assumed_scale / ((negative_update ** 2).mean().sqrt() + eps)) if direct == 0.0: - return -lr * moving_update.reshape(*grad.shape) - - # row/col normalization of direct/bypass gradient "delta". - direct_row_stats.mul_(beta2).add_((delta ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) - delta = delta / (direct_row_stats + eps).sqrt() - direct_col_stats.mul_(beta2).add_((delta ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) - delta = delta / (direct_col_stats + eps).sqrt() + ans = -lr * negative_update + else: + ans = ((1. - direct) * -lr) * negative_update + direct * no_momentum_step(group, state, grad) - ans = (-lr * (1-direct)) * moving_update + (-lr * direct) * delta - return ans.reshape(*grad.shape) + return ans.reshape(orig_shape) def scaling_step(group, param, state, grad): + # we reach here for biases and weights but not scalars. + # This does three things things: + # (i) multiply the step from "cubic_decay" by an estimate of the parameter scale + # (ii) apply parameter decay + # (iii) update the parameter scale, which means shrinking or growing the whole tensor lr = group["lr"] - - momentum = 0.95 + momentum = group["scale_momentum"] # e.g. 0.95 is_weight = grad.ndim >= 2 min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] - # the "scale" is implicitly a scalar, even though it is learned in log space; apply scalar_scale to its + # the scaling factor is implicitly a scalar; apply scalar_scale to its # learning rate. scalar_scale = group["scalar_scale"] - if grad.ndim >= 2 and grad.numel() != max(grad.shape): delta = cubic_decay_step(group, state, grad) else: @@ -205,19 +246,20 @@ def scaling_step(group, param, state, grad): try: scale = state["scale"] scale_grad_buf = state["scale_grad_buffer"] - except: - scale = (param ** 2).mean().sqrt().clamp(min=min_scale, max=max_scale).to(torch.float) + except KeyError: + scale = (param ** 2).mean().sqrt().clamp(min=min_scale, + max=max_scale).to(torch.float) scale_grad_buf = torch.zeros_like(scale) state["scale"] = scale state["scale_grad_buffer"] = scale_grad_buf scale_grad = (grad * param.detach()).sum() - scale_grad_buf.mul_(momentum).add_(scale_grad) + scale_grad_buf.mul_(momentum).add_(scale_grad) # simple momentum old_scale = scale.clone() - scale.add_(scale_grad_buf.sign() * old_scale, alpha=-lr * scalar_scale) + scale.mul_(1. - lr * scalar_scale * scale_grad_buf.sign()) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale @@ -227,12 +269,15 @@ def scaling_step(group, param, state, grad): def adam_step(group, state, grad): + # this is the adam update but with a slight modification / simplification on + # how "bias correction" (startup on small step counts) is dealt with. lr = group["lr"] step = state["step"] eps = group["eps"] - # just hardcode these. we only use this code for biases and scalars. - beta1 = 0.98 - beta2 = 0.98 + beta1 = group["adam_beta1"] + # the following modification to beta2 makes it unnecessary to do bias correction; + # for small step values, we are just computing the mean over the steps so far + beta2 = min(group["adam_beta2"], step / (step + 1)) try: exp_avg = state["exp_avg"] @@ -246,11 +291,7 @@ def adam_step(group, state, grad): exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1)) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - bias_correction2 = 1 - beta2 ** (step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - denom = (exp_avg_sq + eps).sqrt() + denom = exp_avg_sq.sqrt() + eps return -lr * (exp_avg / denom) @@ -282,10 +323,13 @@ def __init__( direct=0.15, # scale on bypass of momentum (beta1) cubic_decay_proportion=0.8, beta2=0.98, - eps=1.0e-16, + eps=1.0e-08, weight_scale_limits=(0.05, 0.25), bias_scale_limits=(0.05, 0.25), scalar_scale=0.075, + adam_beta1=0.98, + adam_beta2=0.98, + scale_momentum=0.95, ): defaults = dict( lr=lr, @@ -297,6 +341,9 @@ def __init__( weight_scale_limits=weight_scale_limits, bias_scale_limits=bias_scale_limits, scalar_scale=scalar_scale, + adam_beta1=adam_beta1, + adam_beta2=adam_beta2, + scale_momentum=scale_momentum, ) super().__init__(params, defaults) @@ -331,6 +378,9 @@ def step(self, closure=None): state["step"] = 0 cur_step = 0 + def u(x): + return x.unsqueeze(0) + if p.numel() == 1: # "scalar_scale" the assumed parameter scale used for # scalars, in this case it just acts as a multiplier on @@ -384,11 +434,11 @@ def _test_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.015 + lr = 0.017 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 - optim = Rubik(m.parameters(), lr=lr, direct=0.0, beta1=0.999) + optim = Rubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999) num_epochs = 180 @@ -453,18 +503,18 @@ def lr_lambda(current_step): logging.info(f"output_magnitudes = {output_magnitudes}") -def _test_compute_scaled_prod3(): +def _test_scaled_three_way_product(): x = torch.randn(16, 32) _U, _S, V = torch.linalg.svd(x, full_matrices=False) W = V * torch.randn(1, 1) # so now all the singular values of x will be identical (but arbitrary) - X = compute_scaled_prod3(W) + X = scaled_three_way_product(W) #print("X = ", X[0]) #print("W = ", W[0]) assert torch.allclose(W, X, atol=1.0e-03) # but the result won't be identical to the input if the singular values are not all identical. - assert not torch.allclose(x, compute_scaled_prod3(x), atol=1.0e-03) + assert not torch.allclose(x, scaled_three_way_product(x), atol=1.0e-03) if __name__ == "__main__": torch.set_num_threads(1) @@ -483,5 +533,5 @@ def _test_compute_scaled_prod3(): else: hidden_dim = 200 - _test_compute_scaled_prod3() + _test_scaled_three_way_product() _test_rubik(hidden_dim) From 4bbd3eb5004778239224ff3e2cc511271e4ca046 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 20 Apr 2026 23:02:43 +0800 Subject: [PATCH 1077/1191] Reduce safety_factor from 0.66 to 0.5 --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ++-- egs/librispeech/ASR/zapformer/rubik.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index aab59bbadc..e68e0cff77 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -167,9 +167,9 @@ def clip_alpha(x: Tensor, y: Tensor, alpha: float) -> Tensor: # min_sum_scale the scale beta such that (x + beta y) is minimized; x and # y each have 2 dimensions. min_sum_scale is expected to be negative. min_sum_scale = -(x * y).sum(dim=(1, 2), keepdim=True) / ((y ** 2).sum(dim=(1, 2), keepdim=True) + 1.0e-40) - # the safety factor of 0.66 means, don't go all the way to where the dot product of the + # the safety factor of 0.5 means, don't go all the way to where the dot product of the # change to x with x would be zero, only go some way to there. - safety_factor = 0.66 + safety_factor = 0.5 alpha = (safety_factor * min_sum_scale).clamp(min=alpha) return alpha diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 5d1f0304b7..cd44d8a657 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -60,9 +60,9 @@ def clip_alpha(x: Tensor, y: Tensor, alpha: float) -> Tensor: # min_sum_scale the scale beta such that (x + beta y) is minimized; x and # y each have 2 dimensions. min_sum_scale is expected to be negative. min_sum_scale = -(x * y).sum() / ((y ** 2).sum() + 1.0e-40) - # the safety factor of 0.66 means, don't go all the way to where the dot product of the + # the safety factor of 0.5 means, don't go all the way to where the dot product of the # change to x with x would be zero, only go some way to there. - safety_factor = 0.66 + safety_factor = 0.5 alpha = (safety_factor * min_sum_scale).clamp(min=alpha) return alpha From 28fb712a0820f8e2f236943249b8e1b6768ccf5c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 21 Apr 2026 12:59:36 +0800 Subject: [PATCH 1078/1191] Bug fix in sign of linear decay in rubik --- egs/librispeech/ASR/zapformer/batched_rubik.py | 11 +++-------- egs/librispeech/ASR/zapformer/rubik.py | 8 ++++---- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index e68e0cff77..aab5086a75 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -293,13 +293,8 @@ def cubic_decay_step(group, state, grad): linear_alpha = -(1-beta1) - cubic_alpha # will be negative. - # the next line undoes the preconditioning so we can accumulate gradient - # stats in the "canonical basis" of the gradients, for consistency. - moving_grad_cubic_decay = moving_grad_precon * invP - moving_grad_linear_decay = moving_grad * beta1 - moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. - linear_alpha) + moving_grad_precon.mul_(1. + linear_alpha) # equiv to: moving_grad_precon.add_(moving_grad_precon, alpha=linear_alpha) # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP @@ -309,7 +304,7 @@ def cubic_decay_step(group, state, grad): # rather than as a gradient. it is going to be very close to: # negative_update = moving_grad_precon / invP # but we also update the preconditioner. Note: practically speaking we are multiplying - # by the same thing twice though. + # by the same thing twice, i.e. dividing "grad" twice by invP. negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what @@ -662,7 +657,7 @@ def _test_batched_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.017 + lr = 0.015 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index cd44d8a657..9920640ee3 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -188,12 +188,12 @@ def cubic_decay_step(group, state, grad): # but with a precaution for divergence. cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) - # cubic_alpha shape: (batch_size, 1, 1) + # cubic_alpha shape: (batch_size, 1, 1). it will be negative. linear_alpha = -(1-beta1) - cubic_alpha # will be negative. moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. - linear_alpha) + moving_grad_precon.mul_(1. + linear_alpha) # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP @@ -203,7 +203,7 @@ def cubic_decay_step(group, state, grad): # rather than as a gradient. it is going to be very close to: # negative_update = moving_grad_precon / invP # but we also update the preconditioner. Note: practically speaking we are multiplying - # by the same thing twice though. + # by the same thing twice, i.e. dividing "grad" twice by invP. negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what @@ -434,7 +434,7 @@ def _test_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.017 + lr = 0.015 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 From c4c887959dbf69f7515338323cb338e4e17c41a8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 21 Apr 2026 13:55:22 +0800 Subject: [PATCH 1079/1191] Take zapformer.py from 3045, reducing min of final residual_scale from .5 to 1/num_layers. --- egs/librispeech/ASR/zapformer/zapformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 0457210ee2..f45c250d6c 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -869,7 +869,7 @@ def forward( aux_loss_scale=aux_loss_scale/num_layers, ) residual_scale = limit_param_value(self.residual_scales[i + 1], - min=0.0 if i + 1 < num_layers else 0.5, + min=0.0 if i + 1 < num_layers else min(0.5, 1. / num_layers), max=1.0) src_with_bypass = src_with_bypass + residual_scale * src From 94f7bbd624e5828b95cb6d0e07a910338ea5b1d7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 21 Apr 2026 16:25:30 +0800 Subject: [PATCH 1080/1191] Add random embedding projections to tensorboard. --- .../ASR/zapformer/batched_rubik.py | 37 ++++++++++++++++++- egs/librispeech/ASR/zapformer/train.py | 5 ++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index aab5086a75..0f9f6c498d 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -443,6 +443,7 @@ def __init__( adam_beta1=0.98, adam_beta2=0.98, scale_momentum=0.95, + tb_writer=None, ): defaults = dict( @@ -464,6 +465,7 @@ def __init__( super(BatchedRubik, self).__init__(param_groups, defaults) assert len(self.param_groups) == len(parameters_names) self.parameters_names = parameters_names + self.tb_writer = tb_writer def _get_names_of_parameters( self, params_or_named_params @@ -592,6 +594,10 @@ def step(self, closure=None): batch = True + # accumulate a random projection of the parameters in the tensorboard for purposes of graphing. + generator = None + rand_proj = 0.0 + for group, group_params_names in zip(self.param_groups, self.parameters_names): with self.batched_params(group["params"], group_params_names) as batches: @@ -612,11 +618,40 @@ def step(self, closure=None): else: p += scaling_step(group, p.detach(), state, grad) + + if self.tb_writer is not None: + with torch.no_grad(): + generator, rand_proj = self._accumulate_random_projection(generator, rand_proj, p) + state["step"] = cur_step + 1 - return loss + if self.tb_writer is not None: + rand_proj = rand_proj.to('cpu') + for i in range(rand_proj.numel()): + self.tb_writer.add_scalar(f'train/rand_proj{i+1}', rand_proj[i], cur_step) + return loss + def _accumulate_random_projection(self, + generator: Optional[torch.Generator], + rand_proj: Union[float, Tensor], + p: Tensor): + num_lines = 2 + # plot two separate lines. Caution: don't increase this to a large number. Tensorboard + # relies on an extremely slow mechanism based on python semaphores or something like + # that, to add items to plot, and it can only handle a certain rate of these scalars. + # adding any more + if generator is None: + generator = torch.Generator(device=p.device) + generator.manual_seed(100) # must have same seed each time to make the plot meaningful + # this is called at the beginning of each step. + if rand_proj is 0.0: + rand_proj = torch.zeros(num_lines, device=p.device) + + for i in range(num_lines): + proj = torch.randn(*p.shape, generator=generator, device=p.device) + rand_proj[i] += (p * proj).sum() + return generator, rand_proj def _test_batched_rubik(hidden_dim: int): import timeit diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 28b147841b..63b1803791 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -76,7 +76,7 @@ # the try-pass blocks around imports are to reduce the chance of failures when running multiple code # versions in parallel; later, these can be removed. try: - from batched_rubik import BatchedRubik as Rubik + from batched_rubik import BatchedRubik # could also have done: # from rubik import Rubik except: @@ -1358,12 +1358,13 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = Rubik( + optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, direct=0.15, cubic_decay_proportion=0.8, beta1=0.995, + tb_writer=tb_writer, ) From 09667e601ae45f20bfbada2ab7765173f8879312 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 21 Apr 2026 19:22:22 +0800 Subject: [PATCH 1081/1191] Print random state for debugging consistency. --- egs/librispeech/ASR/zapformer/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 63b1803791..b03df8b85f 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1222,6 +1222,10 @@ def save_bad_model(suffix: str = ""): f"lr: {cur_lr:.2e}, " + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") ) + logging.info( + f"rng_state={torch.cuda.get_rng_state()}" + ) + if tb_writer is not None: tb_writer.add_scalar( From de570939873613bf3aa1935d010af00b6b676a68 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 21 Apr 2026 20:36:24 +0800 Subject: [PATCH 1082/1191] Print augmented features sum to check consistency --- egs/librispeech/ASR/zapformer/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index b03df8b85f..e3619acc6c 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -959,6 +959,12 @@ def compute_loss( with torch.amp.autocast('cuda', enabled=False): features = specaug(features.to(torch.float), feature_lens) + + if batch_idx_train % 50 == 0: + logging.info( + f"rng_state={torch.cuda.get_rng_state()}, features-sum={features.sum()}" + ) + with torch.set_grad_enabled(is_training): simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( x=features, @@ -1222,9 +1228,6 @@ def save_bad_model(suffix: str = ""): f"lr: {cur_lr:.2e}, " + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") ) - logging.info( - f"rng_state={torch.cuda.get_rng_state()}" - ) if tb_writer is not None: From fa3194d9697326495d0cf55cfb011d3ea40ca66a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 21 Apr 2026 22:50:26 +0800 Subject: [PATCH 1083/1191] project full grad, not just sign --- .../ASR/zapformer/batched_rubik.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 0f9f6c498d..2bd84c2c38 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -595,8 +595,7 @@ def step(self, closure=None): batch = True # accumulate a random projection of the parameters in the tensorboard for purposes of graphing. - generator = None - rand_proj = 0.0 + generator, param_proj, grad_proj = None, None, None for group, group_params_names in zip(self.param_groups, self.parameters_names): with self.batched_params(group["params"], group_params_names) as batches: @@ -621,14 +620,17 @@ def step(self, closure=None): if self.tb_writer is not None: with torch.no_grad(): - generator, rand_proj = self._accumulate_random_projection(generator, rand_proj, p) + generator, param_proj = self._accumulate_random_projection(generator, param_proj, p) + generator, grad_proj = self._accumulate_random_projection(generator, grad_proj, grad) state["step"] = cur_step + 1 if self.tb_writer is not None: - rand_proj = rand_proj.to('cpu') - for i in range(rand_proj.numel()): - self.tb_writer.add_scalar(f'train/rand_proj{i+1}', rand_proj[i], cur_step) + param_proj = param_proj.to('cpu') + grad_proj = grad_proj.to('cpu') + for i in range(param_proj.numel()): + self.tb_writer.add_scalar(f'train/param_proj{i+1}', param_proj[i], cur_step) + self.tb_writer.add_scalar(f'train/grad_proj{i+1}', grad_proj[i], cur_step) return loss @@ -645,7 +647,7 @@ def _accumulate_random_projection(self, generator = torch.Generator(device=p.device) generator.manual_seed(100) # must have same seed each time to make the plot meaningful # this is called at the beginning of each step. - if rand_proj is 0.0: + if rand_proj is None: rand_proj = torch.zeros(num_lines, device=p.device) for i in range(num_lines): @@ -696,7 +698,11 @@ def _test_batched_rubik(hidden_dim: int): # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 - optim = BatchedRubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999) + + from torch.utils.tensorboard import SummaryWriter + tb_writer = SummaryWriter(log_dir=f"tensorboard") + + optim = BatchedRubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999, tb_writer=tb_writer) num_epochs = 180 From f9d6db64c7785895948892342c4813e4e7d189e7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 22 Apr 2026 10:42:50 +0800 Subject: [PATCH 1084/1191] Move param-random-proj code to debug_params() function, separate from optimizer. --- .../ASR/zapformer/batched_rubik.py | 56 ++++--------------- egs/librispeech/ASR/zapformer/train.py | 41 +++++++++++++- 2 files changed, 50 insertions(+), 47 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 2bd84c2c38..e68e0cff77 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -293,8 +293,13 @@ def cubic_decay_step(group, state, grad): linear_alpha = -(1-beta1) - cubic_alpha # will be negative. + # the next line undoes the preconditioning so we can accumulate gradient + # stats in the "canonical basis" of the gradients, for consistency. + moving_grad_cubic_decay = moving_grad_precon * invP + moving_grad_linear_decay = moving_grad * beta1 + moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. + linear_alpha) # equiv to: moving_grad_precon.add_(moving_grad_precon, alpha=linear_alpha) + moving_grad_precon.mul_(1. - linear_alpha) # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP @@ -304,7 +309,7 @@ def cubic_decay_step(group, state, grad): # rather than as a gradient. it is going to be very close to: # negative_update = moving_grad_precon / invP # but we also update the preconditioner. Note: practically speaking we are multiplying - # by the same thing twice, i.e. dividing "grad" twice by invP. + # by the same thing twice though. negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what @@ -443,7 +448,6 @@ def __init__( adam_beta1=0.98, adam_beta2=0.98, scale_momentum=0.95, - tb_writer=None, ): defaults = dict( @@ -465,7 +469,6 @@ def __init__( super(BatchedRubik, self).__init__(param_groups, defaults) assert len(self.param_groups) == len(parameters_names) self.parameters_names = parameters_names - self.tb_writer = tb_writer def _get_names_of_parameters( self, params_or_named_params @@ -594,9 +597,6 @@ def step(self, closure=None): batch = True - # accumulate a random projection of the parameters in the tensorboard for purposes of graphing. - generator, param_proj, grad_proj = None, None, None - for group, group_params_names in zip(self.param_groups, self.parameters_names): with self.batched_params(group["params"], group_params_names) as batches: @@ -617,43 +617,11 @@ def step(self, closure=None): else: p += scaling_step(group, p.detach(), state, grad) - - if self.tb_writer is not None: - with torch.no_grad(): - generator, param_proj = self._accumulate_random_projection(generator, param_proj, p) - generator, grad_proj = self._accumulate_random_projection(generator, grad_proj, grad) - state["step"] = cur_step + 1 - if self.tb_writer is not None: - param_proj = param_proj.to('cpu') - grad_proj = grad_proj.to('cpu') - for i in range(param_proj.numel()): - self.tb_writer.add_scalar(f'train/param_proj{i+1}', param_proj[i], cur_step) - self.tb_writer.add_scalar(f'train/grad_proj{i+1}', grad_proj[i], cur_step) - return loss - def _accumulate_random_projection(self, - generator: Optional[torch.Generator], - rand_proj: Union[float, Tensor], - p: Tensor): - num_lines = 2 - # plot two separate lines. Caution: don't increase this to a large number. Tensorboard - # relies on an extremely slow mechanism based on python semaphores or something like - # that, to add items to plot, and it can only handle a certain rate of these scalars. - # adding any more - if generator is None: - generator = torch.Generator(device=p.device) - generator.manual_seed(100) # must have same seed each time to make the plot meaningful - # this is called at the beginning of each step. - if rand_proj is None: - rand_proj = torch.zeros(num_lines, device=p.device) - - for i in range(num_lines): - proj = torch.randn(*p.shape, generator=generator, device=p.device) - rand_proj[i] += (p * proj).sum() - return generator, rand_proj + def _test_batched_rubik(hidden_dim: int): import timeit @@ -694,15 +662,11 @@ def _test_batched_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.015 + lr = 0.017 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 - - from torch.utils.tensorboard import SummaryWriter - tb_writer = SummaryWriter(log_dir=f"tensorboard") - - optim = BatchedRubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999, tb_writer=tb_writer) + optim = BatchedRubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999) num_epochs = 180 diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index e3619acc6c..d3773da159 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -107,6 +107,7 @@ save_checkpoint_with_global_batch_idx, update_averaged_model, ) +import torch.distributed as dist from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error @@ -680,6 +681,44 @@ def get_params() -> AttributeDict: return params +def debug_params(model: Union[nn.Module, DDP], + tb_writer: Optional[SummaryWriter] = None, + step: int = 0, + seed: int = 1): # can try different seeds if you want. + if isinstance(model, DDP): + model = model.module + device = next(model.parameters()).device + generator = torch.Generator(device=device) + generator.manual_seed(seed) + with torch.no_grad(): + param_proj = torch.tensor(0.0, device=device) + grad_proj = torch.tensor(0.0, device=device) + for p in model.parameters(): + proj = torch.randn(p.shape, generator=generator, device=p.device) + param_proj = param_proj + (p * proj).sum() + try: + grad_proj = grad_proj + (p.grad * proj).sum() + except AttributeError: + pass + + def dump(proj: Tensor, name: str): + proj_min = proj.clone() + proj_max = proj.clone() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(proj_min, op=dist.ReduceOp.MIN) + dist.all_reduce(proj_max, op=dist.ReduceOp.MAX) + dist.all_reduce(proj, op=dist.ReduceOp.SUM) + proj = proj / dist.get_world_size() + proj_diff = proj_max - proj_min + if tb_writer is not None: + tb_writer.add_scalar(name + '_diff', proj_diff.item(), step) + if tb_writer is not None: + tb_writer.add_scalar(name, proj.item(), step) + dump(param_proj, f'train/param_proj{seed}') + dump(grad_proj, f'train/grad_proj{seed}') + + + def _to_int_tuple(s: str): return tuple(map(int, s.split(","))) @@ -1151,6 +1190,7 @@ def save_bad_model(suffix: str = ""): scheduler.set_batch(batch_idx) # sets batch-count within the epoch, and sets the LRs. scaler.step(optimizer) scaler.update() + debug_params(model, tb_writer, params.batch_idx_train, seed=1) optimizer.zero_grad() except Exception as e: logging.info(f"Caught exception: {e}.") @@ -1371,7 +1411,6 @@ def run(rank, world_size, args): direct=0.15, cubic_decay_proportion=0.8, beta1=0.995, - tb_writer=tb_writer, ) From da575d85165c254297346e975b0db38429cfbbb6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 22 Apr 2026 11:34:32 +0800 Subject: [PATCH 1085/1191] Restore correlation limiter but with enormous limit, of 0.25. --- egs/librispeech/ASR/zapformer/zapformer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index f45c250d6c..e9699e128e 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -600,7 +600,11 @@ def __init__( self.offset_scale_limiter = ScaleLimiter(max_rms=1.0) #power = 0.35 # power should be between 0 and 1. 1 would mean cov == I (unattainable) - #self.correlation_limiter = CorrelationLimiter(limit=(1. / (embed_dim ** power))) + #limit = (1. / (embed_dim ** power))) + limit = 0.25 # this is very enormous limit on correlations, it's just to prevent divergence + # and bad parameter locations from which it's impossible for the optimizer to escape. i.e. + # it should impose no real limitation on "normal" training runs. + self.correlation_limiter = CorrelationLimiter(limit=limit) self.self_attn = MultiheadRelPosGatedSelfAttention( embed_dim, @@ -649,9 +653,9 @@ def forward( """ src_orig = src - #src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), - # 2. * aux_loss_scale, mask=src_key_padding_mask), - #None) + src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), + 2. * aux_loss_scale, mask=src_key_padding_mask), + None) src_pre_ff1 = src From 912c605219af08aca96035b9c5d30d7debef0351 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 22 Apr 2026 11:35:47 +0800 Subject: [PATCH 1086/1191] Make debug printout of correlations more frequent and print the limit. --- egs/librispeech/ASR/zapformer/zapformer_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py index d98ec3c253..434ea104de 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_modules.py +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -670,9 +670,9 @@ def norm(x: Tensor): correlation = (S1 * S2).mean() loss = (correlation - ctx.limit).relu() - if random.random() < 0.0001: + if random.random() < 0.001: logging.info( - f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, loss={loss}" + f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, limit={ctx.limit}, loss={loss}" ) loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device)) From 655dca37588836971dd3bbf44bf78bebbe5cfb40 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 22 Apr 2026 14:38:26 +0800 Subject: [PATCH 1087/1191] Introduce adafactor_beta1=0.9, add conventional momentum into direct term. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index e68e0cff77..182644d03e 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -223,6 +223,7 @@ def no_momentum_step(group, state, grad): step = state["step"] lr = group["lr"] eps = group["eps"] + adafactor_beta1 = min(0.9, step / (step + 1)) # the following modification to beta2 warms up beta2 gradually. # For the first step we just take the current stats; this is similar to @@ -233,13 +234,20 @@ def no_momentum_step(group, state, grad): try: row_stats = state["direct_row_stats"] col_stats = state["direct_col_stats"] + adafactor_momentum = state["adafactor_momentum"] except KeyError: row_stats = torch.zeros(batch_size, rows, 1, device=grad.device, dtype=grad.dtype) col_stats = torch.zeros(batch_size, 1, cols, device=grad.device, dtype=grad.dtype) + adafactor_momentum = torch.zeros(batch_size, rows, cols, device=grad.device, dtype=grad.dtype) state["direct_row_stats"] = row_stats state["direct_col_stats"] = col_stats + state["adafactor_momentum"] = adafactor_momentum - return -lr * normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + norm_grad = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + adafactor_momentum.mul_(adafactor_beta1) + adafactor_momentum.add_(norm_grad, alpha=1.-adafactor_beta1) + + return -lr * adafactor_momentum def cubic_decay_step(group, state, grad): From f640f6ea9d064057b4acdbdf912eeb11ae8a827a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 22 Apr 2026 14:39:15 +0800 Subject: [PATCH 1088/1191] Increase direct from 0.15 to 0.25 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d3773da159..56c0767f88 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1408,7 +1408,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.15, + direct=0.25, cubic_decay_proportion=0.8, beta1=0.995, ) From 0895abb06c0e57f59e668eb9aa2c5c2c0b5a3a14 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 22 Apr 2026 18:43:09 +0800 Subject: [PATCH 1089/1191] Reduce direct scale from .25 to .15 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 56c0767f88..d3773da159 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1408,7 +1408,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.25, + direct=0.15, cubic_decay_proportion=0.8, beta1=0.995, ) From f325df45b14e8d5a323a69428d7841f6bf24a317 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Apr 2026 16:04:41 +0800 Subject: [PATCH 1090/1191] Change adfactor_beta1 to -0.5 and direct to 0.05 --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- egs/librispeech/ASR/zapformer/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 182644d03e..c94f567a5a 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -223,7 +223,7 @@ def no_momentum_step(group, state, grad): step = state["step"] lr = group["lr"] eps = group["eps"] - adafactor_beta1 = min(0.9, step / (step + 1)) + adafactor_beta1 = 0. if step == 0 else -0.5 # the following modification to beta2 warms up beta2 gradually. # For the first step we just take the current stats; this is similar to diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d3773da159..eddca882f5 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1408,7 +1408,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.15, + direct=0.05, cubic_decay_proportion=0.8, beta1=0.995, ) From d0cae9b77b433df4f0a12b61497487f542994e39 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Apr 2026 19:56:50 +0800 Subject: [PATCH 1091/1191] Print debug_grad less frequently. --- egs/librispeech/ASR/zapformer/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index eddca882f5..31e7306276 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1190,7 +1190,8 @@ def save_bad_model(suffix: str = ""): scheduler.set_batch(batch_idx) # sets batch-count within the epoch, and sets the LRs. scaler.step(optimizer) scaler.update() - debug_params(model, tb_writer, params.batch_idx_train, seed=1) + if params.batch_idx_train < 2000 or params.batch_idx_train % 1000 < 100: + debug_params(model, tb_writer, params.batch_idx_train, seed=1) optimizer.zero_grad() except Exception as e: logging.info(f"Caught exception: {e}.") From 9462cc5557c2442e41ff3224ff46d3e33600e39c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 12:53:26 +0800 Subject: [PATCH 1092/1191] Decrease direct from .05 to .01, adafactor_beta1 from -0.5 to -0.9 and introduce warmup to it over 4k steps. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- egs/librispeech/ASR/zapformer/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index c94f567a5a..921ddb2ec6 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -223,7 +223,7 @@ def no_momentum_step(group, state, grad): step = state["step"] lr = group["lr"] eps = group["eps"] - adafactor_beta1 = 0. if step == 0 else -0.5 + adafactor_beta1 = -0.9 * min(1, step / 4000) # the following modification to beta2 warms up beta2 gradually. # For the first step we just take the current stats; this is similar to diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 31e7306276..c5336ec456 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1409,7 +1409,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.05, + direct=0.01, cubic_decay_proportion=0.8, beta1=0.995, ) From 8428671d4827d4afc7cda7b5e50c1bc644d6f40e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 13:48:18 +0800 Subject: [PATCH 1093/1191] take debugging-only changes to train.py from 3078. --- egs/librispeech/ASR/zapformer/train.py | 97 ++++++++++++++++---------- 1 file changed, 60 insertions(+), 37 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index c5336ec456..296f729a56 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -681,41 +681,63 @@ def get_params() -> AttributeDict: return params -def debug_params(model: Union[nn.Module, DDP], - tb_writer: Optional[SummaryWriter] = None, - step: int = 0, - seed: int = 1): # can try different seeds if you want. - if isinstance(model, DDP): - model = model.module - device = next(model.parameters()).device - generator = torch.Generator(device=device) - generator.manual_seed(seed) - with torch.no_grad(): - param_proj = torch.tensor(0.0, device=device) - grad_proj = torch.tensor(0.0, device=device) - for p in model.parameters(): - proj = torch.randn(p.shape, generator=generator, device=p.device) - param_proj = param_proj + (p * proj).sum() - try: - grad_proj = grad_proj + (p.grad * proj).sum() - except AttributeError: - pass - - def dump(proj: Tensor, name: str): - proj_min = proj.clone() - proj_max = proj.clone() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(proj_min, op=dist.ReduceOp.MIN) - dist.all_reduce(proj_max, op=dist.ReduceOp.MAX) - dist.all_reduce(proj, op=dist.ReduceOp.SUM) - proj = proj / dist.get_world_size() - proj_diff = proj_max - proj_min - if tb_writer is not None: - tb_writer.add_scalar(name + '_diff', proj_diff.item(), step) - if tb_writer is not None: - tb_writer.add_scalar(name, proj.item(), step) - dump(param_proj, f'train/param_proj{seed}') - dump(grad_proj, f'train/grad_proj{seed}') +class ParamPlotter: + def __init__(self, + model: Union[nn.Module, DDP], + tb_writer: Optional[SummaryWriter], + period: int = 50): + if isinstance(model, DDP): + model = model.module + self.model = model + self.tb_writer = tb_writer + device = next(model.parameters()).device + self.device = device + self.period = period + self.grad_proj = torch.tensor(0.0, device=device) + + def step(self, batch_idx_train: int): + if batch_idx_train % self.period > 1: + return + + generator = torch.Generator(device=self.device) + generator.manual_seed(1) + + + with torch.no_grad(): + param_proj = torch.tensor(0.0, device=self.device) + grad_proj = torch.tensor(0.0, device=self.device) + for p in self.model.parameters(): + proj = torch.randn(p.shape, generator=generator, device=self.device) + param_proj = param_proj + (p * proj).sum() + try: + grad_proj = grad_proj + (p.grad * proj).sum() + except AttributeError: + pass + + tb_writer = self.tb_writer + def dump(proj: Tensor, name: str): + proj_min = proj.clone() + proj_max = proj.clone() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(proj_min, op=dist.ReduceOp.MIN) + dist.all_reduce(proj_max, op=dist.ReduceOp.MAX) + dist.all_reduce(proj, op=dist.ReduceOp.SUM) + proj = proj / dist.get_world_size() + proj_diff = proj_max - proj_min + if tb_writer is not None: + tb_writer.add_scalar(name + '_diff', proj_diff.item(), batch_idx_train) + if tb_writer is not None: + tb_writer.add_scalar(name, proj.item(), batch_idx_train) + if batch_idx_train % self.period == 0: + dump(param_proj, f'train/param_proj') + dump(grad_proj, f'train/grad_proj') + self.grad_proj = grad_proj + elif tb_writer is not None: + assert batch_idx_train % self.period == 1, batch_idx_train + tb_writer.add_scalar('train/grad_same_sign', (grad_proj * self.grad_proj).sign(), batch_idx_train) + + + @@ -1142,6 +1164,8 @@ def train_one_epoch( saved_bad_model = False + param_plotter = ParamPlotter(model, tb_writer, period=50) + def get_scaler_scale(): if params.use_autocast and scaler._scale is not None: return scaler._scale.item() @@ -1190,8 +1214,7 @@ def save_bad_model(suffix: str = ""): scheduler.set_batch(batch_idx) # sets batch-count within the epoch, and sets the LRs. scaler.step(optimizer) scaler.update() - if params.batch_idx_train < 2000 or params.batch_idx_train % 1000 < 100: - debug_params(model, tb_writer, params.batch_idx_train, seed=1) + param_plotter.step(params.batch_idx_train) optimizer.zero_grad() except Exception as e: logging.info(f"Caught exception: {e}.") From ad80d42e5690bf2d137b2b9f7ef3fbe710a35b75 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 13:56:34 +0800 Subject: [PATCH 1094/1191] Plot grad_proj for 50 out of every 1000 steps so we can get a sense for how stable the oscillations are. --- egs/librispeech/ASR/zapformer/train.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 296f729a56..a535f4ea7a 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -696,7 +696,12 @@ def __init__(self, self.grad_proj = torch.tensor(0.0, device=device) def step(self, batch_idx_train: int): - if batch_idx_train % self.period > 1: + # in addition to plotting param_proj and grad_proj and grad_proj_sign every "period" steps, + # plot grad_proj for the first 50 out of every 1000 steps; this will give us a sense of how + # stable the oscillations are. + dense_period = 1000 + dense_length = 50 + if batch_idx_train % self.period > 1 and batch_idx_train % dense_period > dense_length: return generator = torch.Generator(device=self.device) @@ -730,13 +735,11 @@ def dump(proj: Tensor, name: str): tb_writer.add_scalar(name, proj.item(), batch_idx_train) if batch_idx_train % self.period == 0: dump(param_proj, f'train/param_proj') - dump(grad_proj, f'train/grad_proj') self.grad_proj = grad_proj - elif tb_writer is not None: - assert batch_idx_train % self.period == 1, batch_idx_train - tb_writer.add_scalar('train/grad_same_sign', (grad_proj * self.grad_proj).sign(), batch_idx_train) - - + if batch_idx_train % self.period == 1 and tb_writer is not None: + tb_writer.add_scalar('train/grad_same_sign', (grad_proj * self.grad_proj).sign(), batch_idx_train) + if (batch_idx_train % dense_period < dense_length or batch_idx_train % self.period == 0) and tb_writer is not None: + tb_writer.add_scalar('train/grad_proj', grad_proj, batch_idx_train) From 673c6b66f6171a02d4465e6252c68fde4dbfdb86 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 14:01:44 +0800 Subject: [PATCH 1095/1191] Make direct interpreted as a learning rate, not a scale, set it to 0.00015. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 7 ++++--- egs/librispeech/ASR/zapformer/train.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 921ddb2ec6..0101352b34 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -247,7 +247,7 @@ def no_momentum_step(group, state, grad): adafactor_momentum.mul_(adafactor_beta1) adafactor_momentum.add_(norm_grad, alpha=1.-adafactor_beta1) - return -lr * adafactor_momentum + return adafactor_momentum def cubic_decay_step(group, state, grad): @@ -332,7 +332,8 @@ def cubic_decay_step(group, state, grad): if direct == 0.0: ans = -lr * negative_update else: - ans = ((1. - direct) * -lr) * negative_update + direct * no_momentum_step(group, state, grad) + # now interpret direct as a fixed learning rate, not a scale on the learning rate. + ans = -lr * negative_update + -direct * no_momentum_step(group, state, grad) return ans.reshape(orig_shape) @@ -674,7 +675,7 @@ def _test_batched_rubik(hidden_dim: int): # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 - optim = BatchedRubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999) + optim = BatchedRubik(m.parameters(), lr=lr, direct=0.0001, beta1=0.999) num_epochs = 180 diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index a535f4ea7a..6d31c12f5f 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.01, + direct=0.00015, cubic_decay_proportion=0.8, beta1=0.995, ) From 7f7b85f808bd36545284637042295d5d21f282a9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 17:15:32 +0800 Subject: [PATCH 1096/1191] Make adafactor update fully cancel, subtract it all on the next step via adafactor_beta1=0.0. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 0101352b34..19c22e92d9 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -223,7 +223,7 @@ def no_momentum_step(group, state, grad): step = state["step"] lr = group["lr"] eps = group["eps"] - adafactor_beta1 = -0.9 * min(1, step / 4000) + adafactor_beta1 = 0.0 # the following modification to beta2 warms up beta2 gradually. # For the first step we just take the current stats; this is similar to @@ -244,10 +244,12 @@ def no_momentum_step(group, state, grad): state["adafactor_momentum"] = adafactor_momentum norm_grad = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + + prev_momentum = adafactor_momentum.clone() adafactor_momentum.mul_(adafactor_beta1) adafactor_momentum.add_(norm_grad, alpha=1.-adafactor_beta1) - return adafactor_momentum + return norm_grad - prev_momentum # cancels it out over the long term so we're just adding noise/instability def cubic_decay_step(group, state, grad): From ae82911a3adc65c765c2ea2ea9d258450b53d8c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 17:18:02 +0800 Subject: [PATCH 1097/1191] Increase direct scale from .00015 to .0015. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 6d31c12f5f..ae4da11f18 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.00015, + direct=0.0015, cubic_decay_proportion=0.8, beta1=0.995, ) From 6c7bf06d0a5951f902f610c6b877c86027be412f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 18:04:26 +0800 Subject: [PATCH 1098/1191] Reduce beta2 used in no_momentum_step from 0.98 to 0.9. --- egs/librispeech/ASR/zapformer/rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 9920640ee3..a829cf1f92 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -121,7 +121,7 @@ def no_momentum_step(group, state, grad): # the following modification to beta2 warms up beta2 gradually. # For the first step we just take the current stats; this is similar to # a sign-only update. - beta2 = min(group["beta2"], 1. - 1. / (1. + 0.2 * step)) + beta2 = min(0.9, 1. - 1. / (1. + 0.2 * step)) (rows, cols) = grad.shape From a500f0e8e976854017ef875afa197bed4b8b91a0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 19:30:43 +0800 Subject: [PATCH 1099/1191] Reduce beta2 used in no_momentum_step from 0.9 to 0.0 --- egs/librispeech/ASR/zapformer/batched_rubik.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 19c22e92d9..d7b1e80537 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -225,10 +225,11 @@ def no_momentum_step(group, state, grad): eps = group["eps"] adafactor_beta1 = 0.0 - # the following modification to beta2 warms up beta2 gradually. - # For the first step we just take the current stats; this is similar to - # a sign-only update. - beta2 = min(group["beta2"], 1. - 1. / (1. + 0.2 * step)) + ## the following modification to beta2 warms up beta2 gradually. + ## For the first step we just take the current stats; this is similar to + ## a sign-only update. + #beta2 = min(group["beta2"], 1. - 1. / (1. + 0.2 * step)) + beta2 = 0.0 # so actually just immediately normalize. (batch_size, rows, cols) = grad.shape try: From ccf5cfbac743cd2f7a53d6ce7bff1e9daedad8e5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 19:45:18 +0800 Subject: [PATCH 1100/1191] Warm up cancellation of direct gradient over 4k batches. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index d7b1e80537..e54d653798 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -224,6 +224,8 @@ def no_momentum_step(group, state, grad): lr = group["lr"] eps = group["eps"] adafactor_beta1 = 0.0 + warm_steps = 4000 # warm up cancellation over 4k steps + cancellation_scale = min(1.0, step / warm_steps) ## the following modification to beta2 warms up beta2 gradually. ## For the first step we just take the current stats; this is similar to @@ -248,7 +250,7 @@ def no_momentum_step(group, state, grad): prev_momentum = adafactor_momentum.clone() adafactor_momentum.mul_(adafactor_beta1) - adafactor_momentum.add_(norm_grad, alpha=1.-adafactor_beta1) + adafactor_momentum.add_(norm_grad, alpha=(1.-adafactor_beta1) * cancellation_scale) return norm_grad - prev_momentum # cancels it out over the long term so we're just adding noise/instability From 3eff0015c0608e9f108e21228cff410edbb922a0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 19:50:54 +0800 Subject: [PATCH 1101/1191] Reduce direct=0.0015 to direct=0.0005. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index ae4da11f18..54df373a87 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.0015, + direct=0.0005, cubic_decay_proportion=0.8, beta1=0.995, ) From 730220ef3c394af225f72b436c7055b352ba3c77 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Apr 2026 23:10:43 +0800 Subject: [PATCH 1102/1191] Change adafactor_beta1 from 0.0 to -0.5 (beta1 used in no_momentum_step) --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index e54d653798..bb93ceb696 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -223,7 +223,7 @@ def no_momentum_step(group, state, grad): step = state["step"] lr = group["lr"] eps = group["eps"] - adafactor_beta1 = 0.0 + adafactor_beta1 = -0.5 warm_steps = 4000 # warm up cancellation over 4k steps cancellation_scale = min(1.0, step / warm_steps) From 313bf5e6e931f046123088039b01cc8c5a226d7e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 25 Apr 2026 13:29:44 +0800 Subject: [PATCH 1103/1191] Have direct grad immediately normalized and have it warm down over 5k batches. --- .../ASR/zapformer/batched_rubik.py | 59 +++++-------------- egs/librispeech/ASR/zapformer/train.py | 2 +- 2 files changed, 15 insertions(+), 46 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index bb93ceb696..c8d289c020 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -213,46 +213,13 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): return x / (col_stats.sqrt() + eps) -def no_momentum_step(group, state, grad): - # computes an update direction using magnitude normalization but no momentum - # (no beta1, in adam terminology). grad is assumed to have exactly three - # dimensions (grad.ndim == 3), representing (batch_size, rows, cols). - # the grad is normalized using adafactor-like - # row and column statistics, but done sequentially over first rows and then - # columns - step = state["step"] - lr = group["lr"] - eps = group["eps"] - adafactor_beta1 = -0.5 - warm_steps = 4000 # warm up cancellation over 4k steps - cancellation_scale = min(1.0, step / warm_steps) - - ## the following modification to beta2 warms up beta2 gradually. - ## For the first step we just take the current stats; this is similar to - ## a sign-only update. - #beta2 = min(group["beta2"], 1. - 1. / (1. + 0.2 * step)) - beta2 = 0.0 # so actually just immediately normalize. - - (batch_size, rows, cols) = grad.shape - try: - row_stats = state["direct_row_stats"] - col_stats = state["direct_col_stats"] - adafactor_momentum = state["adafactor_momentum"] - except KeyError: - row_stats = torch.zeros(batch_size, rows, 1, device=grad.device, dtype=grad.dtype) - col_stats = torch.zeros(batch_size, 1, cols, device=grad.device, dtype=grad.dtype) - adafactor_momentum = torch.zeros(batch_size, rows, cols, device=grad.device, dtype=grad.dtype) - state["direct_row_stats"] = row_stats - state["direct_col_stats"] = col_stats - state["adafactor_momentum"] = adafactor_momentum +def norm_rows_and_cols(x, eps): + row_denom = (x ** 2).mean(dim=2, keepdim=True).sqrt() + eps + x = x / row_denom + col_denom = (x ** 2).mean(dim=1, keepdim=True).sqrt() + eps + x = x / col_denom + return x - norm_grad = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) - - prev_momentum = adafactor_momentum.clone() - adafactor_momentum.mul_(adafactor_beta1) - adafactor_momentum.add_(norm_grad, alpha=(1.-adafactor_beta1) * cancellation_scale) - - return norm_grad - prev_momentum # cancels it out over the long term so we're just adding noise/instability def cubic_decay_step(group, state, grad): @@ -262,7 +229,10 @@ def cubic_decay_step(group, state, grad): beta_ceil = 1. - 1. / (10. + 0.2 * step) beta1 = min(group["beta1"], beta_ceil) beta2 = min(group["beta2"], beta_ceil) - direct = group["direct"] # scale on non-momentum step + direct_batches = 5000 # only use direct grad for first 5k batches. + direct = group["direct"] * max(0, 1. - step / direct_batches) # scale on non-momentum step, helpful for warmup + + cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion @@ -334,11 +304,10 @@ def cubic_decay_step(group, state, grad): negative_update = negative_update * (assumed_scale / ((negative_update ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) - if direct == 0.0: - ans = -lr * negative_update - else: - # now interpret direct as a fixed learning rate, not a scale on the learning rate. - ans = -lr * negative_update + -direct * no_momentum_step(group, state, grad) + ans = -lr * negative_update + + if direct != 0.0: + ans = ans - direct * norm_rows_and_cols(grad, eps) return ans.reshape(orig_shape) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 54df373a87..f6415398a9 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.0005, + direct=0.001, cubic_decay_proportion=0.8, beta1=0.995, ) From 6815809927f4958619f4e8bbd9160a45319ddcd4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 26 Apr 2026 11:26:01 +0800 Subject: [PATCH 1104/1191] Have direct lr warm down to 0.2 of its initial value, not zero. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index c8d289c020..80c84f69db 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -230,7 +230,7 @@ def cubic_decay_step(group, state, grad): beta1 = min(group["beta1"], beta_ceil) beta2 = min(group["beta2"], beta_ceil) direct_batches = 5000 # only use direct grad for first 5k batches. - direct = group["direct"] * max(0, 1. - step / direct_batches) # scale on non-momentum step, helpful for warmup + direct = group["direct"] * max(0.2, 1. - step / direct_batches) # scale on non-momentum step, helpful for warmup cubic_decay_proportion = group["cubic_decay_proportion"] From c9b8fa8047335f5422313604069b1a040b6e0089 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 27 Apr 2026 11:51:00 +0800 Subject: [PATCH 1105/1191] Change batched_rubik to ignore direct grad term and just use nesterov modification. --- .../ASR/zapformer/batched_rubik.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 80c84f69db..3239b66ac1 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -213,14 +213,6 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): return x / (col_stats.sqrt() + eps) -def norm_rows_and_cols(x, eps): - row_denom = (x ** 2).mean(dim=2, keepdim=True).sqrt() + eps - x = x / row_denom - col_denom = (x ** 2).mean(dim=1, keepdim=True).sqrt() + eps - x = x / col_denom - return x - - def cubic_decay_step(group, state, grad): lr = group["lr"] @@ -229,8 +221,6 @@ def cubic_decay_step(group, state, grad): beta_ceil = 1. - 1. / (10. + 0.2 * step) beta1 = min(group["beta1"], beta_ceil) beta2 = min(group["beta2"], beta_ceil) - direct_batches = 5000 # only use direct grad for first 5k batches. - direct = group["direct"] * max(0.2, 1. - step / direct_batches) # scale on non-momentum step, helpful for warmup cubic_decay_proportion = group["cubic_decay_proportion"] @@ -266,6 +256,7 @@ def cubic_decay_step(group, state, grad): invP = row_denom * col_denom # inverse preconditioner P moving_grad_precon = moving_grad / invP # preconditioned moving_grad + cur_grad_precon = grad / invP # this step's contribution to moving_grad_precon, used for nesterov modification # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. @@ -293,6 +284,11 @@ def cubic_decay_step(group, state, grad): # negative_update = moving_grad_precon / invP # but we also update the preconditioner. Note: practically speaking we are multiplying # by the same thing twice though. + + nesterov = True + if nesterov: + moving_grad_precon = moving_grad_precon + cur_grad_precon + negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what @@ -300,15 +296,14 @@ def cubic_decay_step(group, state, grad): # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style # accumulator with beta equal to beta1. # This should make divergence less likely. + # we ignore nesterov modification for purposes of this formula, it should make little difference anyway + # if beta1 is close to 1. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) negative_update = negative_update * (assumed_scale / ((negative_update ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) ans = -lr * negative_update - if direct != 0.0: - ans = ans - direct * norm_rows_and_cols(grad, eps) - return ans.reshape(orig_shape) @@ -421,7 +416,7 @@ def __init__( params, lr=1.2e-02, beta1=0.995, - direct=0.15, # scale on bypass of momentum (beta1) + direct=0.15, # Now ignored. cubic_decay_proportion=0.8, beta2=0.98, eps=1.0e-08, @@ -649,7 +644,7 @@ def _test_batched_rubik(hidden_dim: int): # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 - optim = BatchedRubik(m.parameters(), lr=lr, direct=0.0001, beta1=0.999) + optim = BatchedRubik(m.parameters(), lr=lr, beta1=0.999) num_epochs = 180 From 918f935be4c30e223e0eee85c9dd0ceaf8993711 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 27 Apr 2026 17:58:32 +0800 Subject: [PATCH 1106/1191] Remove lr_scale from OrthogonalLinear and replace it with weight_rms argument --- egs/librispeech/ASR/zapformer/zapformer.py | 10 ++++----- .../ASR/zapformer/zapformer_modules.py | 22 +++++++++---------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index e9699e128e..9b559cf7d0 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -33,7 +33,6 @@ OrthogonalLinear, RmsNorm, SequenceNorm, - OrthogonalLinear, ScaledLinear, # just an initializer for Linear SwashR, ScaleLimiter, @@ -303,8 +302,8 @@ def compute_projection_overlap(self, verbose: bool = False): for i in range(N): for j in range(i): # multipying by lr_scale keeps the scale correct so they will be orthogonal - proj_i = self.encoders[i].proj.weight * self.encoders[i].proj.lr_scale - proj_j = self.encoders[j].proj.weight * self.encoders[j].proj.lr_scale + proj_i = self.encoders[i].proj.get_weight() + proj_j = self.encoders[j].proj.get_weight() if proj_i.shape[1] > proj_j.shape[1]: proj_i, proj_j = proj_j, proj_i # swap them in_dim_i = proj_i.shape[1] # now this is <= proj_j.shape[1] @@ -805,8 +804,9 @@ def __init__( super().__init__() # self.downsample will also reverse the downsampling operation for us afterward. - self.proj = OrthogonalLinear(dim, encoder_layer.embed_dim, - lr_scale=0.66, bias=False) + self.proj = OrthogonalLinear(dim, + encoder_layer.embed_dim, + bias=False) self.name = None self.layers = nn.ModuleList( diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py index 434ea104de..f7845ea94c 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_modules.py +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -514,10 +514,10 @@ class OrthogonalLinear(nn.Linear): Args: in_channels: number of input channels out_channels: number of output channels - lr_scale: we will scale the weight by this value before applying the orthogonal - constraint and using it; with most optimizers - this will have the effect of slowing down the learning by this factor because - the parameter value will be larger. + weight_rms: the rms value of the physical weights in self.weights; we rescale the weights + to achieve this while respecting the orthogonal constraint, as a way + of reducing the relative learning speed of this module. (larger weight_rms -> + slower learning, in general). bias: if True, include a bias term. penalty_scale: a scale on the penalty on non-orthogonality (this will be multiplied by the average-absolute-value of the @@ -528,30 +528,28 @@ class OrthogonalLinear(nn.Linear): def __init__(self, in_channels: int, out_channels: int, - lr_scale: float = 1.0, + weight_rms: float = 0.2, bias: bool = True, penalty_scale: float = 20.0, ): super().__init__(in_channels, out_channels, bias=bias) self.name = None self.penalty_scale = copy.deepcopy(penalty_scale) - self.lr_scale = lr_scale + self.weight_scale = (in_channels ** -0.5) / weight_rms with torch.no_grad(): - self.weight[:] = torch.randn(out_channels, in_channels) * (in_channels ** -0.5) * (1. / lr_scale) + self.weight[:] = torch.randn(out_channels, in_channels) * weight_rms if self.bias is not None: torch.nn.init.uniform_(self.bias, -0.01, 0.01) + def get_weight(self): + return self.weight * self.weight_scale def forward(self, x: Tensor, transpose: bool = False): # you can only use transpose=True if you used bias=False in initialization - weight = self.weight - lr_scale = self.lr_scale - if lr_scale != 1.0: - weight = weight * lr_scale + weight = self.get_weight() if self.training and not torch.jit.is_scripting() and not torch.jit.is_tracing(): weight = OrthogonalPenaltyFunction.apply(weight, float(self.penalty_scale), self.name) - if transpose: weight = weight.t() return torch.nn.functional.linear(x, weight, self.bias) From 7bf24f7e8ef91f8430878597900d399989e8102b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 27 Apr 2026 19:59:30 +0800 Subject: [PATCH 1107/1191] Introduce factor of 0.5 in decay that comes from the math; reduce scalar_scale and scale_limits to compensate. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 3239b66ac1..152fdc2821 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -349,7 +349,7 @@ def scaling_step(group, param, state, grad): scale_ratio = scale / old_scale - delta_scale = (scale_ratio * (1 - (lr ** 2))) - 1 + delta_scale = (scale_ratio * (1 - 0.5 * (lr ** 2))) - 1 return param * delta_scale + scale * delta @@ -420,9 +420,9 @@ def __init__( cubic_decay_proportion=0.8, beta2=0.98, eps=1.0e-08, - weight_scale_limits=(0.05, 0.25), - bias_scale_limits=(0.05, 0.25), - scalar_scale=0.075, + weight_scale_limits=(0.03, 0.15), + bias_scale_limits=(0.03, 0.15), + scalar_scale=0.05, adam_beta1=0.98, adam_beta2=0.98, scale_momentum=0.95, @@ -640,7 +640,7 @@ def _test_batched_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.017 + lr = 0.024 # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the # optimum parameters very exactly. Normally you want something more like the # defaults of beta1=0.995 and direct=0.15 From f9c6cb9b3475f3821e8f07fff6bc1825d48b2094 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 27 Apr 2026 20:13:38 +0800 Subject: [PATCH 1108/1191] Code cleanups, and propagate recent updates to batched_rubik.py to rubik.py --- .../ASR/zapformer/batched_rubik.py | 21 ++----- egs/librispeech/ASR/zapformer/rubik.py | 63 +++++-------------- egs/librispeech/ASR/zapformer/train.py | 1 - 3 files changed, 20 insertions(+), 65 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 152fdc2821..bc7835581d 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -267,28 +267,22 @@ def cubic_decay_step(group, state, grad): linear_alpha = -(1-beta1) - cubic_alpha # will be negative. - # the next line undoes the preconditioning so we can accumulate gradient - # stats in the "canonical basis" of the gradients, for consistency. - moving_grad_cubic_decay = moving_grad_precon * invP - moving_grad_linear_decay = moving_grad * beta1 - moving_grad_precon.add_(prod3 * cubic_alpha) moving_grad_precon.mul_(1. - linear_alpha) # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP - # Now compute "negative_update" which is negative_update_precon multiplied again by the + nesterov = True + if nesterov: + moving_grad_precon = moving_grad_precon + cur_grad_precon + + # Now compute "negative_update" which is moving_grad_precon multiplied again by the # preconditioner, this takes us from the preconditioned to the canonical co-ordinates but now treating the quantity as a parameter-update # rather than as a gradient. it is going to be very close to: # negative_update = moving_grad_precon / invP # but we also update the preconditioner. Note: practically speaking we are multiplying # by the same thing twice though. - - nesterov = True - if nesterov: - moving_grad_precon = moving_grad_precon + cur_grad_precon - negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what @@ -416,7 +410,6 @@ def __init__( params, lr=1.2e-02, beta1=0.995, - direct=0.15, # Now ignored. cubic_decay_proportion=0.8, beta2=0.98, eps=1.0e-08, @@ -431,7 +424,6 @@ def __init__( defaults = dict( lr=lr, beta1=beta1, - direct=direct, cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, @@ -641,9 +633,6 @@ def _test_batched_rubik(hidden_dim: int): ] lr = 0.024 - # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the - # optimum parameters very exactly. Normally you want something more like the - # defaults of beta1=0.995 and direct=0.15 optim = BatchedRubik(m.parameters(), lr=lr, beta1=0.999) num_epochs = 180 diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index a829cf1f92..3c54f7c1f4 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -108,34 +108,6 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): return x / (col_stats.sqrt() + eps) -def no_momentum_step(group, state, grad): - # computes an update direction using magnitude normalization but no momentum - # (no beta1, in adam terminology). grad is assumed to have exactly two - # dimensions (grad.ndim == 2). the grad is normalized using adafactor-like - # row and column statistics, but done sequentially over first rows and then - # columns - step = state["step"] - lr = group["lr"] - eps = group["eps"] - - # the following modification to beta2 warms up beta2 gradually. - # For the first step we just take the current stats; this is similar to - # a sign-only update. - beta2 = min(0.9, 1. - 1. / (1. + 0.2 * step)) - - (rows, cols) = grad.shape - - try: - row_stats = state["direct_row_stats"] - col_stats = state["direct_col_stats"] - except KeyError: - row_stats = torch.zeros(rows, 1, device=grad.device, dtype=grad.dtype) - col_stats = torch.zeros(1, cols, device=grad.device, dtype=grad.dtype) - state["direct_row_stats"] = row_stats - state["direct_col_stats"] = col_stats - - return -lr * normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) - def cubic_decay_step(group, state, grad): lr = group["lr"] @@ -145,7 +117,6 @@ def cubic_decay_step(group, state, grad): beta1 = min(group["beta1"], beta_ceil) beta2 = min(group["beta2"], beta_ceil) - direct = group["direct"] # scale on non-momentum step cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion @@ -178,14 +149,12 @@ def cubic_decay_step(group, state, grad): invP = row_denom * col_denom # inverse preconditioner P moving_grad_precon = moving_grad / invP # preconditioned moving_grad + cur_grad_precon = grad / invP # this step's contribution to moving_grad_precon, used for nesterov modification # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad_precon) - # next line similar to: - # moving_grad_precon.add_(prod3, alpha=-(1-beta1)) - # but with a precaution for divergence. cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) # cubic_alpha shape: (batch_size, 1, 1). it will be negative. @@ -193,11 +162,15 @@ def cubic_decay_step(group, state, grad): linear_alpha = -(1-beta1) - cubic_alpha # will be negative. moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. + linear_alpha) + moving_grad_precon.mul_(1. - linear_alpha) # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP + nesterov = True + if nesterov: + moving_grad_precon = moving_grad_precon + cur_grad_precon + # Now compute "negative_update" which is negative_update_precon multiplied again by the # preconditioner, this takes us from the preconditioned to the canonical co-ordinates but now treating the quantity as a parameter-update # rather than as a gradient. it is going to be very close to: @@ -211,14 +184,13 @@ def cubic_decay_step(group, state, grad): # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style # accumulator with beta equal to beta1. # This should make divergence less likely. + # we ignore nesterov modification for purposes of this formula, it should make little difference anyway + # if beta1 is close to 1. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) negative_update = negative_update * (assumed_scale / ((negative_update ** 2).mean().sqrt() + eps)) - if direct == 0.0: - ans = -lr * negative_update - else: - ans = ((1. - direct) * -lr) * negative_update + direct * no_momentum_step(group, state, grad) + ans = -lr * negative_update return ans.reshape(orig_shape) @@ -264,7 +236,7 @@ def scaling_step(group, param, state, grad): scale_ratio = scale / old_scale - delta_scale = (scale_ratio * (1 - (lr ** 2))) - 1 + delta_scale = (scale_ratio * (1 - 0.5 * (lr ** 2))) - 1 return param * delta_scale + scale * delta @@ -320,13 +292,12 @@ def __init__( params, lr=1.2e-02, beta1=0.995, - direct=0.15, # scale on bypass of momentum (beta1) cubic_decay_proportion=0.8, beta2=0.98, eps=1.0e-08, - weight_scale_limits=(0.05, 0.25), - bias_scale_limits=(0.05, 0.25), - scalar_scale=0.075, + weight_scale_limits=(0.03, 0.15), + bias_scale_limits=(0.03, 0.15), + scalar_scale=0.05, adam_beta1=0.98, adam_beta2=0.98, scale_momentum=0.95, @@ -334,7 +305,6 @@ def __init__( defaults = dict( lr=lr, beta1=beta1, - direct=direct, cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, @@ -434,11 +404,8 @@ def _test_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.015 - # the very large beta1 and zero "direct" value is specifically for this test task, which approaches the - # optimum parameters very exactly. Normally you want something more like the - # defaults of beta1=0.995 and direct=0.15 - optim = Rubik(m.parameters(), lr=lr, direct=0.05, beta1=0.999) + lr = 0.024 + optim = Rubik(m.parameters(), lr=lr, beta1=0.999) num_epochs = 180 diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index f6415398a9..7aa4406cd0 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,6 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), lr=params.base_lr, - direct=0.001, cubic_decay_proportion=0.8, beta1=0.995, ) From e4e723490152c9c21b0a480216abb1d6e87797c8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 28 Apr 2026 11:29:10 +0800 Subject: [PATCH 1109/1191] Change factor in beta_ceil from 0.2 to 0.1 for slower warmup of beta. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index bc7835581d..56ee6710a6 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -218,7 +218,7 @@ def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] - beta_ceil = 1. - 1. / (10. + 0.2 * step) + beta_ceil = 1. - 1. / (10. + 0.1 * step) beta1 = min(group["beta1"], beta_ceil) beta2 = min(group["beta2"], beta_ceil) From 2cd1fc0fd8dd39c119da5d7f516d116ec9d9acd6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 28 Apr 2026 12:12:15 +0800 Subject: [PATCH 1110/1191] Remove all special things from initialization and training of depthwise_conv: diagonal emphasis, lr_scale. --- egs/librispeech/ASR/zapformer/zapformer.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 9b559cf7d0..bd8716edb3 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -301,7 +301,6 @@ def compute_projection_overlap(self, verbose: bool = False): N = len(self.encoders) for i in range(N): for j in range(i): - # multipying by lr_scale keeps the scale correct so they will be orthogonal proj_i = self.encoders[i].proj.get_weight() proj_j = self.encoders[j].proj.get_weight() if proj_i.shape[1] > proj_j.shape[1]: @@ -2010,13 +2009,6 @@ def __init__( ) self.left_pad = kernel_size - 1 - self.depthwise_conv.lr_scale = 0.66 # not sure whether to keep this, it wasn't very conclusive. - with torch.no_grad(): - # make the non-central convolution weights much smaller. - k2 = kernel_size // 2 - self.depthwise_conv.weight[..., :k2] *= 0.1 - self.depthwise_conv.weight[..., -k2:] *= 0.1 - # add average-of-all-frames to the "convolution."; it has extra power vs the convolution # because the num frames differs between utterances. self.weighted_mean = WeightedMean(bottleneck_dim, From 29cddb14fbda78ad101ae59301b59ad34c50cafe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 28 Apr 2026 15:45:00 +0800 Subject: [PATCH 1111/1191] Normalize direct grad separately from moving_grad --- .../ASR/zapformer/batched_rubik.py | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 56ee6710a6..9ad72d9eda 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -218,9 +218,10 @@ def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] - beta_ceil = 1. - 1. / (10. + 0.1 * step) - beta1 = min(group["beta1"], beta_ceil) - beta2 = min(group["beta2"], beta_ceil) + beta1_ceil = 1. - 1. / (10. + 0.1 * step) + beta1 = min(group["beta1"], beta1_ceil) + beta2_ceil = step / (step + 1) + beta2 = min(group["beta2"], beta2_ceil) cubic_decay_proportion = group["cubic_decay_proportion"] @@ -234,12 +235,17 @@ def cubic_decay_step(group, state, grad): if "moving_grad" not in state: assert step < 2 state["moving_grad"] = torch.zeros(batch_size, rows, cols, device=grad.device) + state["moving_row_stats"] = torch.ones(batch_size, rows, 1, device=grad.device) + state["moving_col_stats"] = torch.ones(batch_size,1, cols, device=grad.device) state["row_stats"] = torch.ones(batch_size, rows, 1, device=grad.device) - state["col_stats"] = torch.ones(batch_size,1, cols, device=grad.device) + state["col_stats"] = torch.ones(batch_size, 1, cols, device=grad.device) + moving_grad = state["moving_grad"] row_stats = state["row_stats"] col_stats = state["col_stats"] + moving_row_stats = state["moving_row_stats"] + moving_col_stats = state["moving_col_stats"] # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -273,17 +279,28 @@ def cubic_decay_step(group, state, grad): # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP - nesterov = True - if nesterov: - moving_grad_precon = moving_grad_precon + cur_grad_precon - # Now compute "negative_update" which is moving_grad_precon multiplied again by the # preconditioner, this takes us from the preconditioned to the canonical co-ordinates but now treating the quantity as a parameter-update # rather than as a gradient. it is going to be very close to: # negative_update = moving_grad_precon / invP # but we also update the preconditioner. Note: practically speaking we are multiplying # by the same thing twice though. - negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) + negative_update = normalize_and_update_stats(moving_grad_precon, moving_row_stats, moving_col_stats, beta2, eps) + + norm_grad = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + + moving_grad_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + # moving_grad_assumed_scale is the scale negative_update "should be" if it were decayed moving average of normalized stats, + # with scales: (1-beta1), (1-beta1) beta1, (1-beta1) beta1**2, etc. + + nesterov = True + if nesterov: + # the scale ((1 - beta1**2)**0.5) on grad is derived as follows: + # norm_grad_assumed_scale = (1-beta1) # the scale in a nesterov-type "count current step twice". + # coeff = norm_grad_assumed_scale / negative_grad_assumed_scale + # = ((1 - beta1**2)**0.5) + negative_update = negative_update + norm_grad * ((1 - beta1**2)**0.5) + # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. @@ -292,9 +309,8 @@ def cubic_decay_step(group, state, grad): # This should make divergence less likely. # we ignore nesterov modification for purposes of this formula, it should make little difference anyway # if beta1 is close to 1. - assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - negative_update = negative_update * (assumed_scale / ((negative_update ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) + negative_update = negative_update * (moving_grad_assumed_scale / ((negative_update ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) ans = -lr * negative_update From 6aaa46c476a7312741b4d626f892d7c199427270 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 28 Apr 2026 16:08:42 +0800 Subject: [PATCH 1112/1191] Fix comment --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 9ad72d9eda..e560124c8c 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -297,7 +297,7 @@ def cubic_decay_step(group, state, grad): if nesterov: # the scale ((1 - beta1**2)**0.5) on grad is derived as follows: # norm_grad_assumed_scale = (1-beta1) # the scale in a nesterov-type "count current step twice". - # coeff = norm_grad_assumed_scale / negative_grad_assumed_scale + # coeff = norm_grad_assumed_scale / moving_grad_assumed_scale # = ((1 - beta1**2)**0.5) negative_update = negative_update + norm_grad * ((1 - beta1**2)**0.5) From d9a96ab856d0e316c89425ab30a16ba3d2dad05f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 28 Apr 2026 16:14:54 +0800 Subject: [PATCH 1113/1191] Bug fix, use moving_stats --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index e560124c8c..51fc77340e 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -257,8 +257,8 @@ def cubic_decay_step(group, state, grad): # Looking at this code may give the impression that we are mistakenly # normalizing "twice". Actually we have an "equilibrium argument" why this # is actually OK and will give correctly-normalized data. - row_denom = (row_stats.sqrt() + eps) - col_denom = (col_stats.sqrt() + eps) + row_denom = (moving_row_stats.sqrt() + eps) + col_denom = (moving_col_stats.sqrt() + eps) invP = row_denom * col_denom # inverse preconditioner P moving_grad_precon = moving_grad / invP # preconditioned moving_grad From 70c4178c94d95267067571ebc943d10dda5e5b17 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 28 Apr 2026 18:47:44 +0800 Subject: [PATCH 1114/1191] Apply nesterov also to adam_step and scaling_step --- .../ASR/zapformer/batched_rubik.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 51fc77340e..6d90e2d4f8 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -352,9 +352,18 @@ def scaling_step(group, param, state, grad): scale_grad = (grad * param.detach()).sum(dim=dims, keepdim=True) scale_grad_buf.mul_(momentum).add_(scale_grad) # simple momentum + nesterov = True + if nesterov: + # simple interpretation of nesterov: do an extra step of + # moving-average on scale_grad_buf, with scale_grad, like double-counting + # it. + negative_update = (scale_grad_buf * momentum + scale_grad).sign() + else: + negative_update = scale_grad_buf.sign() + old_scale = scale.clone() - scale.mul_(1. - lr * scalar_scale * scale_grad_buf.sign()) + scale.mul_(1. - lr * scalar_scale * negative_update) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale @@ -388,7 +397,14 @@ def adam_step(group, state, grad): exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) denom = exp_avg_sq.sqrt() + eps - return -lr * (exp_avg / denom) + nesterov = True + if nesterov: + # this is similar to double-counting grad + moving_grad = exp_avg * beta1 + grad * (1-beta1) + else: + moving_grad = exp_avg + + return -lr * (moving_grad / denom) From d2d2f810c0ad51b1dac4b68d435530631a1b239c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 1 May 2026 22:21:33 +0800 Subject: [PATCH 1115/1191] Restore code to down weight non-central depthwise_conv weights on initialization. --- egs/librispeech/ASR/zapformer/zapformer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index bd8716edb3..8065acdb20 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -2009,6 +2009,12 @@ def __init__( ) self.left_pad = kernel_size - 1 + with torch.no_grad(): + # make the non-central convolution weights much smaller. + k = kernel_size // 2 + self.depthwise_conv.weight[..., :k] *= 0.1 + self.depthwise_conv.weight[..., -k:] *= 0.1 + # add average-of-all-frames to the "convolution."; it has extra power vs the convolution # because the num frames differs between utterances. self.weighted_mean = WeightedMean(bottleneck_dim, From 91a2ef47b46a1dda0cc008d24fa6af7f5286f098 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 3 May 2026 12:53:26 +0800 Subject: [PATCH 1116/1191] Make compute_projection_overlap more efficient. --- egs/librispeech/ASR/zapformer/zapformer.py | 50 ++++++++++++++-------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 8065acdb20..e54f04964b 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -294,33 +294,45 @@ def forward( def compute_projection_overlap(self, verbose: bool = False): + # This computes a quantity that we'll use as an auxiliary loss. + # It ensures that the projections from more-subsampled sequences "contain" enough of the + # projections from the less-subsampled sequences-- specifically the direction where + # all the less-subsampled projections co-vary in the same way, e.g. if there are + # two frames, that the two frames are identical. + min_overlap = 0.66 # we can tune this tot_loss = 0.0 # between pairs of encoders N = len(self.encoders) + + covs = [] + ranks = [] + for i in range(N): + proj_i = self.encoders[i].proj.get_weight() + cov_i = torch.matmul(proj_i.t(), proj_i) + covs.append(cov_i) + ranks.append(proj_i.shape[0]) + for i in range(N): for j in range(i): - proj_i = self.encoders[i].proj.get_weight() - proj_j = self.encoders[j].proj.get_weight() - if proj_i.shape[1] > proj_j.shape[1]: - proj_i, proj_j = proj_j, proj_i # swap them - in_dim_i = proj_i.shape[1] # now this is <= proj_j.shape[1] - in_dim_j = proj_j.shape[1] - assert in_dim_i <= in_dim_j - assert in_dim_j % in_dim_i == 0 # in_dims must be multiples of each other - R = in_dim_j // in_dim_i # e.g. 1, 2, 4 - assert R in [1, 2, 4, 8] - - proj_i = proj_i.repeat(1, R).reshape(proj_i.shape[0], proj_j.shape[1]) * (R ** -0.5) - # proj_i should still have orthogonal rows. - # now proj_j and proj_i have same dimension one (in_dim) - cov_i = torch.matmul(proj_i.t(), proj_i) - cov_j = torch.matmul(proj_j.t(), proj_j) - # denominator is the minimum of the two rather than their geometric mean, + cov_i, cov_j = covs[i], covs[j] + rank_i, rank_j = ranks[i], ranks[j] + if cov_i.shape[0] > cov_j.shape[0]: + cov_i, cov_j = cov_j, cov_i + rank_i, rank_j = rank_j, rank_i + dim_i = cov_i.shape[0] # now this is <= proj_j.shape[0] + dim_j = cov_j.shape[0] + assert dim_i <= dim_j + assert dim_j % dim_i == 0 # dims must be multiples of each other (these are the + # feature dimension prior to project, i.e. the larger dimensions.) + R = dim_j // dim_i # e.g. 1, 2, 4 + assert R in [1, 2, 4, 8, 16] + cov_i = cov_i.repeat(R, R) * (1. / R) + # denominator is the minimum of the two ranks, # because due to the orthogonal constraint, the maximum possible value of (cov_i * cov_j).sum() would be the - # smaller of the two dimension. - cosine = (cov_i * cov_j).sum() / proj_i.shape[0] + # smaller of the two ranks. + cosine = (cov_i * cov_j).sum() / min(rank_i, rank_j) loss = (min_overlap - cosine).relu() tot_loss = tot_loss + loss From 14a91062f47c9e7422c5303103c2ac7f0528a0e3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 May 2026 14:04:30 +0800 Subject: [PATCH 1117/1191] Remove parameter names from batched_rubik (not functional anyway), simplify its config to just scale_limits, and propagate change RE beta2_ceil to rubik.py. --- .../ASR/zapformer/batched_rubik.py | 159 +++--------------- egs/librispeech/ASR/zapformer/rubik.py | 19 +-- egs/librispeech/ASR/zapformer/train.py | 3 +- 3 files changed, 31 insertions(+), 150 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 6d90e2d4f8..2c8541e591 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -325,8 +325,7 @@ def scaling_step(group, param, state, grad): # (iii) update the parameter scale, which means shrinking or growing the whole tensor lr = group["lr"] momentum = group["scale_momentum"] # e.g. 0.95 - is_weight = grad.ndim >= 2 - min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] + min_scale, max_scale = group["scale_limits"] # the scaling factor is implicitly a scalar; apply scalar_scale to its # learning rate. scalar_scale = group["scalar_scale"] @@ -445,8 +444,7 @@ def __init__( cubic_decay_proportion=0.8, beta2=0.98, eps=1.0e-08, - weight_scale_limits=(0.03, 0.15), - bias_scale_limits=(0.03, 0.15), + scale_limits=(0.03, 0.15), scalar_scale=0.05, adam_beta1=0.98, adam_beta2=0.98, @@ -459,127 +457,14 @@ def __init__( cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, - weight_scale_limits=weight_scale_limits, - bias_scale_limits=bias_scale_limits, + scale_limits=scale_limits, scalar_scale=scalar_scale, adam_beta1=adam_beta1, adam_beta2=adam_beta2, scale_momentum=scale_momentum, ) - param_groups, parameters_names = self._get_names_of_parameters(params) - super(BatchedRubik, self).__init__(param_groups, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - - def _get_names_of_parameters( - self, params_or_named_params - ) -> Tuple[List[Dict], List[List[str]]]: - """ - Args: - params_or_named_params: according to the way TransformedAdam is initialized in train.py, - this argument could be one of following 4 cases, - case 1, a generator of parameter, e.g.: - optimizer = TransformedAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 2, a list of parameter groups with different config, e.g.: - model_param_groups = [ - {'params': model.encoder.parameters(), 'lr': 0.05}, - {'params': model.decoder.parameters(), 'lr': 0.01}, - {'params': model.joiner.parameters(), 'lr': 0.03}, - ] - optimizer = TransformedAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) - - case 3, a generator of named_parameter, e.g.: - optimizer = TransformedAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 4, a list of named_parameter groups with different config, e.g.: - model_named_param_groups = [ - {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, - {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, - {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, - ] - optimizer = TransformedAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) - - For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. - For case 3 and case 4, firstly, names and params are extracted from input named_params, - then, these extracted params are used to initialize the underlying torch.optimizer, - and these extracted names are mainly used by function - `_show_gradient_dominating_parameter` - - Returns: - Returns a tuple containing 2 elements: - - `param_groups` with type List[Dict], each Dict element is a parameter group. - An example of `param_groups` could be: - [ - {'params': `one iterable of Parameter`, 'lr': 0.05}, - {'params': `another iterable of Parameter`, 'lr': 0.08}, - {'params': `a third iterable of Parameter`, 'lr': 0.1}, - ] - - `param_gruops_names` with type List[List[str]], - each `List[str]` is for a group['params'] in param_groups, - and each `str` is the name of a parameter. - A dummy name "foo" is related to each parameter, - if input are params without names, i.e. case 1 or case 2. - """ - # variable naming convention in this function: - # p is short for param. - # np is short for named_param. - # p_or_np is short for param_or_named_param. - # cur is short for current. - # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. - # groups is a List[group] - - iterable_or_groups = list(params_or_named_params) - if len(iterable_or_groups) == 0: - raise ValueError("optimizer got an empty parameter list") - - # The first value of returned tuple. A list of dicts containing at - # least 'params' as a key. - param_groups = [] - - # The second value of returned tuple, - # a List[List[str]], each sub-List is for a group. - param_groups_names = [] - - if not isinstance(iterable_or_groups[0], dict): - # case 1 or case 3, - # the input is an iterable of parameter or named parameter. - param_iterable_cur_group = [] - param_names_cur_group = [] - for p_or_np in iterable_or_groups: - if isinstance(p_or_np, tuple): - # case 3 - name, param = p_or_np - else: - # case 1 - assert isinstance(p_or_np, torch.Tensor) - param = p_or_np - # Assign a dummy name as a placeholder - name = "foo" - self.show_dominant_parameters = False - param_iterable_cur_group.append(param) - param_names_cur_group.append(name) - param_groups.append({"params": param_iterable_cur_group}) - param_groups_names.append(param_names_cur_group) - else: - # case 2 or case 4 - # the input is groups of parameter or named parameter. - for cur_group in iterable_or_groups: - if "named_params" in cur_group: - name_list = [x[0] for x in cur_group["named_params"]] - p_list = [x[1] for x in cur_group["named_params"]] - del cur_group["named_params"] - cur_group["params"] = p_list - else: - assert "params" in cur_group - name_list = ["foo" for _ in cur_group["params"]] - param_groups.append(cur_group) - param_groups_names.append(name_list) - - return param_groups, param_groups_names - - + super(BatchedRubik, self).__init__(params, defaults) def __setstate__(self, state): super(BatchedRubik, self).__setstate__(state) @@ -599,27 +484,29 @@ def step(self, closure=None): batch = True - for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - for p, state, _names in batches: - grad = p.grad + for group in self.param_groups: - try: - cur_step = state["step"] - except KeyError: - state["step"] = 0 - cur_step = 0 + for p in group['params']: + state = self.state[p] + grad = p.grad - if p.numel() == p.shape[0]: - # "scalar_scale" the assumed parameter scale used for - # scalars, in this case it just acts as a multiplier on - # the learning rate. - p += group["scalar_scale"] * adam_step(group, state, grad) - else: - p += scaling_step(group, p.detach(), state, grad) - state["step"] = cur_step + 1 + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + if p.numel() == p.shape[0]: + # "scalar_scale" the assumed parameter scale used for + # scalars, in this case it just acts as a multiplier on + # the learning rate. + p += group["scalar_scale"] * adam_step(group, state, grad) + else: + p += scaling_step(group, p.detach(), state, grad) + + state["step"] = cur_step + 1 return loss diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 3c54f7c1f4..6cfa1cd6a2 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -113,9 +113,10 @@ def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] - beta_ceil = 1. - 1. / (10. + 0.2 * step) - beta1 = min(group["beta1"], beta_ceil) - beta2 = min(group["beta2"], beta_ceil) + beta1_ceil = 1. - 1. / (10. + 0.2 * step) + beta1 = min(group["beta1"], beta1_ceil) + beta2_ceil = step / (step + 1) + beta2 = min(group["beta2"], beta2_ceil) cubic_decay_proportion = group["cubic_decay_proportion"] linear_decay_proportion = 1. - cubic_decay_proportion @@ -203,8 +204,7 @@ def scaling_step(group, param, state, grad): # (iii) update the parameter scale, which means shrinking or growing the whole tensor lr = group["lr"] momentum = group["scale_momentum"] # e.g. 0.95 - is_weight = grad.ndim >= 2 - min_scale, max_scale = group["weight_scale_limits"] if is_weight else group["bias_scale_limits"] + min_scale, max_scale = group["scale_limits"] # the scaling factor is implicitly a scalar; apply scalar_scale to its # learning rate. scalar_scale = group["scalar_scale"] @@ -295,8 +295,7 @@ def __init__( cubic_decay_proportion=0.8, beta2=0.98, eps=1.0e-08, - weight_scale_limits=(0.03, 0.15), - bias_scale_limits=(0.03, 0.15), + scale_limits=(0.03, 0.15), scalar_scale=0.05, adam_beta1=0.98, adam_beta2=0.98, @@ -308,8 +307,7 @@ def __init__( cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, - weight_scale_limits=weight_scale_limits, - bias_scale_limits=bias_scale_limits, + scale_limits=scale_limits, scalar_scale=scalar_scale, adam_beta1=adam_beta1, adam_beta2=adam_beta2, @@ -348,9 +346,6 @@ def step(self, closure=None): state["step"] = 0 cur_step = 0 - def u(x): - return x.unsqueeze(0) - if p.numel() == 1: # "scalar_scale" the assumed parameter scale used for # scalars, in this case it just acts as a multiplier on diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 7aa4406cd0..399397528d 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1433,13 +1433,12 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = BatchedRubik( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), lr=params.base_lr, cubic_decay_proportion=0.8, beta1=0.995, ) - if True: # Work out copies_per_epoch copies_per_epoch = [ ] From 733a896364396146891153483a19d0b773b62811 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 May 2026 14:33:05 +0800 Subject: [PATCH 1118/1191] Bug fixes and properly sync rubik.py with batched_rubik.py --- .../ASR/zapformer/batched_rubik.py | 72 +++++++------------ egs/librispeech/ASR/zapformer/rubik.py | 23 +++++- 2 files changed, 47 insertions(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 2c8541e591..523aab1840 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -43,10 +43,9 @@ def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): + def batched_params(self, param_list): """ - This function returns (technically, yields) a list of - of tuples (p, state), where + This function returns (technically, yields) a list of tuples (p, state), where p is a `fake` parameter that is stacked (over axis 0) from real parameters that share the same shape, and its gradient is also stacked; `state` is the state corresponding to this batch of parameters @@ -65,7 +64,7 @@ def batched_params(self, param_group, group_params_names): you can do: with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: + for p, state in batches: ... @@ -78,31 +77,18 @@ def batched_params(self, param_group, group_params_names): batches = defaultdict( list ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): + for p in param_list: key = (str(p.dtype), *p.shape) batches[key].append(p) - batches_names[key].append(named_p) - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - stacked_params_dict = dict() - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # tuples will contain tuples of (stacked_param, state), # one for each batch in `batches`. tuples = [] - for batch, batch_names in zip(batches, batches_names): + for batch in batches.values(): p = batch[0] # we arbitrarily store the state in the # state corresponding to the 1st parameter in the @@ -113,12 +99,11 @@ def batched_params(self, param_group, group_params_names): [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] ) p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) + tuples.append((p_stacked, state)) yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for ((stacked_params, _state), batch) in zip(tuples, batches.values()): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -484,29 +469,26 @@ def step(self, closure=None): batch = True - for group in self.param_groups: - - for p in group['params']: - state = self.state[p] - grad = p.grad - - - try: - cur_step = state["step"] - except KeyError: - state["step"] = 0 - cur_step = 0 - - if p.numel() == p.shape[0]: - # "scalar_scale" the assumed parameter scale used for - # scalars, in this case it just acts as a multiplier on - # the learning rate. - p += group["scalar_scale"] * adam_step(group, state, grad) - else: - p += scaling_step(group, p.detach(), state, grad) - - state["step"] = cur_step + 1 + with self.batched_params(group["params"]) as batches: + for p, state in batches: + grad = p.grad + + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + if p.numel() == p.shape[0]: + # "scalar_scale" the assumed parameter scale used for + # scalars, in this case it just acts as a multiplier on + # the learning rate. + p += group["scalar_scale"] * adam_step(group, state, grad) + else: + p += scaling_step(group, p.detach(), state, grad) + + state["step"] = cur_step + 1 return loss diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 6cfa1cd6a2..17ead071cb 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -113,7 +113,7 @@ def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] - beta1_ceil = 1. - 1. / (10. + 0.2 * step) + beta1_ceil = 1. - 1. / (10. + 0.1 * step) beta1 = min(group["beta1"], beta1_ceil) beta2_ceil = step / (step + 1) beta2 = min(group["beta2"], beta2_ceil) @@ -231,7 +231,16 @@ def scaling_step(group, param, state, grad): old_scale = scale.clone() - scale.mul_(1. - lr * scalar_scale * scale_grad_buf.sign()) + nesterov = True + if nesterov: + # simple interpretation of nesterov: do an extra step of + # moving-average on scale_grad_buf, with scale_grad, like double-counting + # it. + negative_update = (scale_grad_buf * momentum + scale_grad).sign() + else: + negative_update = scale_grad_buf.sign() + + scale.mul_(1. - lr * scalar_scale * negative_update) scale.clamp_(min=min_scale, max=max_scale) scale_ratio = scale / old_scale @@ -265,7 +274,15 @@ def adam_step(group, state, grad): exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) denom = exp_avg_sq.sqrt() + eps - return -lr * (exp_avg / denom) + nesterov = True + if nesterov: + # this is similar to double-counting grad + moving_grad = exp_avg * beta1 + grad * (1-beta1) + else: + moving_grad = exp_avg + + return -lr * (moving_grad / denom) + From 20e3f4fa78d4fd36612a1c750a017117f9948f92 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 May 2026 14:37:44 +0800 Subject: [PATCH 1119/1191] Fix various comments. --- egs/librispeech/ASR/zapformer/alternating_spec_augment.py | 4 ++-- egs/librispeech/ASR/zapformer/zapformer_utils.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index b27c0336eb..0214f80065 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -9,8 +9,8 @@ class AlternatingSpecAugment(torch.nn.Module): """ AlternatingSpecAugment is a different version of feature-masking and frame-masking - aspects of SpecAugment, without the time warping for now (we use time_warp - from lhotse which is the same as the original SpecAugment). + aspects of SpecAugment, without the time warping for now (we use code for time_warp + adapted from lhotse which is the same as the original SpecAugment). The main difference is in how it selects the regions to be masked, they are selected in a way that usually ensures there is a good amount of space between successive masks. diff --git a/egs/librispeech/ASR/zapformer/zapformer_utils.py b/egs/librispeech/ASR/zapformer/zapformer_utils.py index 6d04f95c80..0470b74690 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_utils.py +++ b/egs/librispeech/ASR/zapformer/zapformer_utils.py @@ -31,10 +31,9 @@ class SoftmaxFunction(torch.autograd.Function): """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. + A memory-efficient implementation of softmax that does not require + storing anything as fp32 in autocast mode. """ - @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) From dd07263d13940fc617461bc029838ccc2bd3a26e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 May 2026 16:48:56 +0800 Subject: [PATCH 1120/1191] Remove input sigmoid-scaling in ConvolutionModule. --- egs/librispeech/ASR/zapformer/zapformer.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index e54f04964b..b6e465f80a 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -1985,18 +1985,14 @@ def __init__( self.in_proj = nn.Linear( channels, - 3 * bottleneck_dim, + bottleneck_dim, ) # the gradients on in_proj are a little noisy, likely to do with the # sigmoid in glu. self.activation1 = Identity() # for diagnostics - self.sigmoid1 = nn.Sigmoid() - - self.sigmoid2 = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics + self.sigmoid = nn.Sigmoid() if not causal: @@ -2059,13 +2055,9 @@ def forward( x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) - x, s, y = x.chunk(3, dim=2) - s = self.sigmoid1(s) - y = self.sigmoid2(y) + x, y = x.chunk(2, dim=2) + y = self.sigmoid(y) x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - # x: (time, batch, channels) # Caution: this module is not completely @@ -2124,10 +2116,8 @@ def streaming_forward( """ x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) - x, s, y = x.chunk(3, dim=2) - s = self.sigmoid1(s) - y = self.sigmoid2(y) - x = x * s + x, y = x.chunk(2, dim=2) + y = self.sigmoid(y) if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0) From e2cfc29b59c0ab0bdb0f4e9adb76a57de4927df4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 May 2026 17:14:08 +0800 Subject: [PATCH 1121/1191] Bug fix --- egs/librispeech/ASR/zapformer/zapformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index b6e465f80a..2c7b6be9ff 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -1985,7 +1985,7 @@ def __init__( self.in_proj = nn.Linear( channels, - bottleneck_dim, + 2 * bottleneck_dim, ) # the gradients on in_proj are a little noisy, likely to do with the # sigmoid in glu. From 77111357dea532ea24ae07a5078c8000052a3172 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 13 May 2026 22:13:48 +0800 Subject: [PATCH 1122/1191] Update RESULTS.md to add basic zapformer recipe --- egs/librispeech/ASR/RESULTS.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index e5a82dfda2..ded7065b38 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,30 @@ ## Results +### zapformer (zapformer + pruned-transducer w/ CTC) + +Note: --num-real-epochs 40 takes about the same time as 20 epochs with the zipformer CR-CTC recipe. +(each epoch is really 3 epochs due to speed-perturb). So the time for training will be roughly 40% +of the old zipformer recipe. The "--epoch 13" reported below is the last epoch, the smaller +number of epochs has to do with the --min-copies,--max-copies, we will add this into the +report later (later epochs take more real computation time because they make different SpecAug +copies of the data.) + +# (non-streaming) +./zapformer/train.py --world-size 4 \ + --min-copies 1 --max-copies 8 --num-real-epochs 40 \ + --base-lr=0.023 --batches-per-epoch 2400 --start-epoch 1 --use-fp16 1 \ + --exp-dir zapformer/exp \ + --use-ctc 1 --use-transducer 1 \ + --base-dim 64 --ctc-loss-scale 0.2 \ + --full-libri 1 --max-duration 1200 --master-port 43039 + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| greedy_search | 1.83 | 3.75 | --epoch 13 --avg 3 | + + + + ### zipformer (zipformer + pruned-transducer w/ CR-CTC) See for more details. From 03d4a70b43704b72d54823a0441c27a005587452 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 16 May 2026 14:19:03 +0800 Subject: [PATCH 1123/1191] Fix formula with linear_alpha having the wrong sign; add some debug code. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 9 ++++++--- egs/librispeech/ASR/zapformer/rubik.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 523aab1840..4aa203b7e8 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -253,13 +253,16 @@ def cubic_decay_step(group, state, grad): # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad_precon) - cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) + cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*cubic_decay_proportion) # cubic_alpha shape: (batch_size, 1, 1) linear_alpha = -(1-beta1) - cubic_alpha # will be negative. + if (step < 1000 and step % 100 == 0) or (step < 100 and step % 10 == 0) or (step % 1000 == 0): + logging.info(f"step={step}, shape={moving_grad_precon.shape}, linear_alpha = {linear_alpha.flatten()}") + moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. - linear_alpha) + moving_grad_precon.mul_(1. + linear_alpha) # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP @@ -425,7 +428,7 @@ def __init__( self, params, lr=1.2e-02, - beta1=0.995, + beta1=0.99, cubic_decay_proportion=0.8, beta2=0.98, eps=1.0e-08, diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 17ead071cb..08d44c8b64 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -157,13 +157,13 @@ def cubic_decay_step(group, state, grad): prod3 = scaled_three_way_product(moving_grad_precon) - cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) + cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*cubic_decay_proportion) # cubic_alpha shape: (batch_size, 1, 1). it will be negative. linear_alpha = -(1-beta1) - cubic_alpha # will be negative. moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. - linear_alpha) + moving_grad_precon.mul_(1. + linear_alpha) # update moving_grad as interpolation between linear decay and cubic decay. moving_grad[:] = moving_grad_precon * invP From 40b58c0b58d3befcf96fe67a013dc0e1d2a10201 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 16 May 2026 17:39:56 +0800 Subject: [PATCH 1124/1191] Increase cubic_decay_proportion from 0.8 to 1.0. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 3 --- egs/librispeech/ASR/zapformer/train.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 4aa203b7e8..ff8339afc8 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -258,9 +258,6 @@ def cubic_decay_step(group, state, grad): linear_alpha = -(1-beta1) - cubic_alpha # will be negative. - if (step < 1000 and step % 100 == 0) or (step < 100 and step % 10 == 0) or (step % 1000 == 0): - logging.info(f"step={step}, shape={moving_grad_precon.shape}, linear_alpha = {linear_alpha.flatten()}") - moving_grad_precon.add_(prod3 * cubic_alpha) moving_grad_precon.mul_(1. + linear_alpha) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 399397528d..a41ba7c4e3 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), lr=params.base_lr, - cubic_decay_proportion=0.8, + cubic_decay_proportion=1.0, beta1=0.995, ) From 8e89313ef7477fa2e749d41cd483c5ea7d79c2d4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 14 May 2026 11:06:23 +0800 Subject: [PATCH 1125/1191] Increase overlap minimum from .66 to .70 --- egs/librispeech/ASR/zapformer/zapformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 2c7b6be9ff..f6bfe50696 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -300,7 +300,7 @@ def compute_projection_overlap(self, verbose: bool = False): # all the less-subsampled projections co-vary in the same way, e.g. if there are # two frames, that the two frames are identical. - min_overlap = 0.66 # we can tune this + min_overlap = 0.7 # we can tune this tot_loss = 0.0 # between pairs of encoders From 7a24ff95374cbc4fdd9579f826417941152ff962 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 May 2026 14:29:46 +0800 Subject: [PATCH 1126/1191] Double nesterov scale. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index ff8339afc8..f78fdb11ae 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -278,13 +278,13 @@ def cubic_decay_step(group, state, grad): # moving_grad_assumed_scale is the scale negative_update "should be" if it were decayed moving average of normalized stats, # with scales: (1-beta1), (1-beta1) beta1, (1-beta1) beta1**2, etc. - nesterov = True - if nesterov: + nesterov = 2.0 # 1.0 would be standard nesterov + if nesterov != 0.0: # the scale ((1 - beta1**2)**0.5) on grad is derived as follows: # norm_grad_assumed_scale = (1-beta1) # the scale in a nesterov-type "count current step twice". # coeff = norm_grad_assumed_scale / moving_grad_assumed_scale # = ((1 - beta1**2)**0.5) - negative_update = negative_update + norm_grad * ((1 - beta1**2)**0.5) + negative_update = negative_update + norm_grad * (nesterov * ((1 - beta1**2)**0.5)) # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what From d6139afa92fd42f58a71c18c33ac98c246f6cfdc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 May 2026 17:55:02 +0800 Subject: [PATCH 1127/1191] Take simpler version of batched_rubik.py that only has one set of stats. --- .../ASR/zapformer/batched_rubik.py | 145 ++++++++++-------- 1 file changed, 83 insertions(+), 62 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index f78fdb11ae..fb8136c926 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -23,6 +23,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch import Tensor from torch.optim import Optimizer @@ -118,15 +119,33 @@ def compute_prod3(x): x2 = torch.matmul(x.transpose(-2, -1), x) return torch.matmul(x, x2) -def three_way_product(x): - """ returns the 3-way matrix product x @ x.t() @ x """ - assert x.ndim >= 2 - if x.shape[0] <= x.shape[1]: - x2 = torch.matmul(x, x.transpose(-2, -1)) - return torch.matmul(x2, x) + +def _three_way_product_chunk(x_chunk): + """Core computation: x_chunk @ x_chunk.T @ x_chunk for a single chunk.""" + if x_chunk.shape[-2] <= x_chunk.shape[-1]: + x2 = torch.matmul(x_chunk, x_chunk.transpose(-2, -1)) + return torch.matmul(x2, x_chunk) else: - x2 = torch.matmul(x.transpose(-2, -1), x) - return torch.matmul(x, x2) + x2 = torch.matmul(x_chunk.transpose(-2, -1), x_chunk) + return torch.matmul(x_chunk, x2) + + +def three_way_product(x, chunk_size=32): + """ returns the 3-way matrix product x @ x.t() @ x + + Processes the batch dimension in chunks to reduce peak GPU memory usage. + The intermediate x @ x.T has shape (batch, rows, rows) which can be very + large; chunking keeps peak memory proportional to chunk_size instead of batch. + """ + assert x.ndim >= 2 + batch = x.shape[0] + if batch <= chunk_size: + return _three_way_product_chunk(x) + results = [] + for start in range(0, batch, chunk_size): + end = min(start + chunk_size, batch) + results.append(_three_way_product_chunk(x[start:end])) + return torch.cat(results, dim=0) def scaled_three_way_product(x): @@ -180,6 +199,25 @@ def matrix_shape(shape): assert False, shape +def update_halfnorm_precon(x, row_stats, col_stats, beta2, eps): + """ + half-normalize the rms of x using row-wise and column-wise stats, while + updating the moving-average stats; return the normalized x. + Shapes: + x: (batch_size, rows, cols) +row_stats: (batch_size, rows, 1) +col_stats: (batch_size, 1, cols) + Returns: + normalized x, shape: (batch_size, rows, cols) + """ + row_stats.mul_(beta2).add_(x.abs().mean(dim=2, keepdim=True), alpha=(1 - beta2)) + row_denom = (row_stats.sqrt() + eps) + x = x / row_denom + col_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) + col_denom = (col_stats.sqrt() + eps) + return x / col_denom, row_denom, col_denom + + def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): """ @@ -192,11 +230,13 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): Returns: normalized x, shape: (batch_size, rows, cols) """ - row_stats.mul_(beta2).add_((x ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) - x = x / (row_stats.sqrt() + eps) - col_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) - return x / (col_stats.sqrt() + eps) - + row_stats.mul_(beta2).add_(x.abs().mean(dim=2, keepdim=True), alpha=(1 - beta2)) + row_denom = (row_stats + eps) + x = x / row_denom + col_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) + col_denom = (col_stats + eps) + x = x / col_denom + return x, row_denom, col_denom def cubic_decay_step(group, state, grad): @@ -220,8 +260,6 @@ def cubic_decay_step(group, state, grad): if "moving_grad" not in state: assert step < 2 state["moving_grad"] = torch.zeros(batch_size, rows, cols, device=grad.device) - state["moving_row_stats"] = torch.ones(batch_size, rows, 1, device=grad.device) - state["moving_col_stats"] = torch.ones(batch_size,1, cols, device=grad.device) state["row_stats"] = torch.ones(batch_size, rows, 1, device=grad.device) state["col_stats"] = torch.ones(batch_size, 1, cols, device=grad.device) @@ -229,62 +267,44 @@ def cubic_decay_step(group, state, grad): moving_grad = state["moving_grad"] row_stats = state["row_stats"] col_stats = state["col_stats"] - moving_row_stats = state["moving_row_stats"] - moving_col_stats = state["moving_col_stats"] - # add the grad to the moving-average grad; the scaling factor used here - # doesn't matter as it all gets normalized later. - moving_grad.add_(grad) - # We'll scale both before and after the cubic decay; this can be viewed as - # doing the cubic decay in a preconditioned space where the preconditioner - # is 1 / row_col_denom. (The row and column stats will be updated later). - # Looking at this code may give the impression that we are mistakenly - # normalizing "twice". Actually we have an "equilibrium argument" why this - # is actually OK and will give correctly-normalized data. - row_denom = (moving_row_stats.sqrt() + eps) - col_denom = (moving_col_stats.sqrt() + eps) - invP = row_denom * col_denom # inverse preconditioner P + # add grad again, like nesterov... just emphasize grad a bit more while also taking into account moving_grad.. + norm_grad, row_denom, col_denom = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) - moving_grad_precon = moving_grad / invP # preconditioned moving_grad - cur_grad_precon = grad / invP # this step's contribution to moving_grad_precon, used for nesterov modification + denom_prod = (row_denom * col_denom) + invP = denom_prod.sqrt() # this sqrt is because we only want to do half of it before and half of it after; they already had .sqrt() done to them. + + norm_grad_precon = norm_grad * invP # undoes half of the normalization + + # add the grad to the moving-average grad; the scaling factor used here + # doesn't matter as it all gets normalized later. + moving_grad.add_(norm_grad_precon) # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. - prod3 = scaled_three_way_product(moving_grad_precon) + prod3 = scaled_three_way_product(moving_grad) - cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*cubic_decay_proportion) + cubic_alpha = clip_alpha(moving_grad, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) # cubic_alpha shape: (batch_size, 1, 1) linear_alpha = -(1-beta1) - cubic_alpha # will be negative. - moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. + linear_alpha) + moving_grad.add_(prod3 * cubic_alpha) + moving_grad.mul_(1. + linear_alpha) - # update moving_grad as interpolation between linear decay and cubic decay. - moving_grad[:] = moving_grad_precon * invP + delta = moving_grad / invP # re-add the half of the normalizatin that we removed - # Now compute "negative_update" which is moving_grad_precon multiplied again by the - # preconditioner, this takes us from the preconditioned to the canonical co-ordinates but now treating the quantity as a parameter-update - # rather than as a gradient. it is going to be very close to: - # negative_update = moving_grad_precon / invP - # but we also update the preconditioner. Note: practically speaking we are multiplying - # by the same thing twice though. - negative_update = normalize_and_update_stats(moving_grad_precon, moving_row_stats, moving_col_stats, beta2, eps) - - norm_grad = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + nesterov = 2.0 # 1.0 would be normal nesterov + if nesterov != 0.0: + delta = delta + (nesterov / beta1) * norm_grad # not in-place. - moving_grad_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - # moving_grad_assumed_scale is the scale negative_update "should be" if it were decayed moving average of normalized stats, - # with scales: (1-beta1), (1-beta1) beta1, (1-beta1) beta1**2, etc. + delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - nesterov = 2.0 # 1.0 would be standard nesterov - if nesterov != 0.0: - # the scale ((1 - beta1**2)**0.5) on grad is derived as follows: - # norm_grad_assumed_scale = (1-beta1) # the scale in a nesterov-type "count current step twice". - # coeff = norm_grad_assumed_scale / moving_grad_assumed_scale - # = ((1 - beta1**2)**0.5) - negative_update = negative_update + norm_grad * (nesterov * ((1 - beta1**2)**0.5)) + #if True: + # + #if step < 5 or (step < 500 and step % 10 == 0): + #logging.info(f"shape={delta.shape}, grad rms is {(grad ** 2).mean(dim=(1,2)).sqrt()}, norm_grad rms is {(norm_grad ** 2).mean(dim=(1,2)).sqrt()}, norm_grad_precon rms is {(norm_grad_precon ** 2).mean(dim=(1,2)).sqrt()}, delta rms is {(delta ** 2).mean(dim=(1,2)).sqrt()}, moving_grad rms is {(moving_grad ** 2).mean(dim=(1,2)).sqrt()}, row_stats_sqrt rms is {row_stats.sqrt().mean(dim=(1,2))}, col_stats sqrt rms is {col_stats.sqrt().mean(dim=(1,2))}") # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what @@ -295,9 +315,9 @@ def cubic_decay_step(group, state, grad): # we ignore nesterov modification for purposes of this formula, it should make little difference anyway # if beta1 is close to 1. - negative_update = negative_update * (moving_grad_assumed_scale / ((negative_update ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) + delta = delta * (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) - ans = -lr * negative_update + ans = -lr * delta return ans.reshape(orig_shape) @@ -391,8 +411,6 @@ def adam_step(group, state, grad): return -lr * (moving_grad / denom) - - class BatchedRubik(BatchedOptimizer): """ Implements a batched version of the Rubik optimizer. @@ -425,8 +443,8 @@ def __init__( self, params, lr=1.2e-02, - beta1=0.99, - cubic_decay_proportion=0.8, + beta1=0.995, + cubic_decay_proportion=1.0, beta2=0.98, eps=1.0e-08, scale_limits=(0.03, 0.15), @@ -474,6 +492,9 @@ def step(self, closure=None): for p, state in batches: grad = p.grad + if dist.is_initialized(): + dist.all_reduce(grad, op=dist.ReduceOp.AVG) + try: cur_step = state["step"] except KeyError: From 1c80a841a62e13fbde6e54239556683a7644cfdb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 May 2026 18:16:26 +0800 Subject: [PATCH 1128/1191] REvert nesterov scale to 1. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index fb8136c926..b577860f77 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -295,9 +295,9 @@ def cubic_decay_step(group, state, grad): delta = moving_grad / invP # re-add the half of the normalizatin that we removed - nesterov = 2.0 # 1.0 would be normal nesterov - if nesterov != 0.0: - delta = delta + (nesterov / beta1) * norm_grad # not in-place. + nesterov = True + if nesterov: + delta = beta1 * delta + norm_grad # not in-place. delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) From 857971af913dab023befa655dbc4b240aa02a055 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 May 2026 19:04:45 +0800 Subject: [PATCH 1129/1191] Double nesterov scale to 2.0. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index b577860f77..b0a3c79340 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -295,9 +295,9 @@ def cubic_decay_step(group, state, grad): delta = moving_grad / invP # re-add the half of the normalizatin that we removed - nesterov = True - if nesterov: - delta = beta1 * delta + norm_grad # not in-place. + nesterov = 2.0 + if nesterov != 0.0: + delta = delta + (nesterov / beta1) * norm_grad # not in-place. delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) From c4e2806268edf0ba2c0c15e7f623d0ef234b2927 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 May 2026 20:37:17 +0800 Subject: [PATCH 1130/1191] Decrease nesterov scale to 0.66. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index b0a3c79340..79e79d6b2a 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -295,7 +295,7 @@ def cubic_decay_step(group, state, grad): delta = moving_grad / invP # re-add the half of the normalizatin that we removed - nesterov = 2.0 + nesterov = 0.66 if nesterov != 0.0: delta = delta + (nesterov / beta1) * norm_grad # not in-place. From f02f152b830e37fd9840cf33aae79fc03845ac31 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 May 2026 22:51:40 +0800 Subject: [PATCH 1131/1191] Remove linear decay; nesterov_scale 0.66->1.0; set cubic_decay_proportion=0.5, beta1=0.99, beta1_ceil const 0.1->0.2. --- .../ASR/zapformer/batched_rubik.py | 29 +++++++++++-------- egs/librispeech/ASR/zapformer/train.py | 4 +-- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 79e79d6b2a..223357a1be 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -243,14 +243,13 @@ def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] - beta1_ceil = 1. - 1. / (10. + 0.1 * step) + beta1_ceil = 1. - 1. / (10. + 0.2 * step) beta1 = min(group["beta1"], beta1_ceil) beta2_ceil = step / (step + 1) beta2 = min(group["beta2"], beta2_ceil) cubic_decay_proportion = group["cubic_decay_proportion"] - linear_decay_proportion = 1. - cubic_decay_proportion orig_shape = grad.shape batch_size = orig_shape[0] @@ -285,19 +284,25 @@ def cubic_decay_step(group, state, grad): # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) - cubic_alpha = clip_alpha(moving_grad, prod3, alpha=-(1-beta1)*(1. - linear_decay_proportion)) - # cubic_alpha shape: (batch_size, 1, 1) - linear_alpha = -(1-beta1) - cubic_alpha # will be negative. + debug = (step % 40 == 0) + if debug: + moving_grad_norm = (moving_grad ** 2).mean(dim=(1,2)).sqrt() + + cubic_alpha = clip_alpha(moving_grad, prod3, alpha=-(1-beta1)*cubic_decay_proportion) + # cubic_alpha shape: (batch_size, 1, 1) moving_grad.add_(prod3 * cubic_alpha) - moving_grad.mul_(1. + linear_alpha) + + if debug: + moving_grad_norm_rel_change = 1. - (moving_grad ** 2).mean(dim=(1,2)).sqrt() / moving_grad_norm + logging.info(f"shape={prod3.shape}, moving_grad_rel_change={moving_grad_norm_rel_change}, vs. target {(1-beta1)}") delta = moving_grad / invP # re-add the half of the normalizatin that we removed - nesterov = 0.66 - if nesterov != 0.0: - delta = delta + (nesterov / beta1) * norm_grad # not in-place. + nesterov = True + if nesterov: + delta = beta1 * delta + norm_grad # not in-place. delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) @@ -335,7 +340,7 @@ def scaling_step(group, param, state, grad): # learning rate. scalar_scale = group["scalar_scale"] - if grad.ndim >= 2 and grad.numel() != max(grad.shape): + if grad.ndim >= 2 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): delta = cubic_decay_step(group, state, grad) else: # biases and similar-shaped tensors @@ -443,8 +448,8 @@ def __init__( self, params, lr=1.2e-02, - beta1=0.995, - cubic_decay_proportion=1.0, + beta1=0.99, + cubic_decay_proportion=0.5, beta2=0.98, eps=1.0e-08, scale_limits=(0.03, 0.15), diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index a41ba7c4e3..fdcf7cba11 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,8 +1435,8 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), lr=params.base_lr, - cubic_decay_proportion=1.0, - beta1=0.995, + cubic_decay_proportion=0.5, + beta1=0.99, ) if True: From d19196fc30497ca71b6496e4d043c084d81c55b4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 00:08:40 +0800 Subject: [PATCH 1132/1191] Change test configuration --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 223357a1be..f33a5cb5a9 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -559,8 +559,8 @@ def _test_batched_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.024 - optim = BatchedRubik(m.parameters(), lr=lr, beta1=0.999) + lr = 0.018 + optim = BatchedRubik(m.parameters(), lr=lr, beta1=0.998) num_epochs = 180 From 6c2e1d3c6566d56990f0a0b1738d32b3da64d5e5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 10:44:24 +0800 Subject: [PATCH 1133/1191] Make cubic_decay_proportion (actually scale) be rank**-0.25, as in nanochat rubik_baseline_tb_dan_largeinit_simpler25 --- egs/librispeech/ASR/zapformer/batched_rubik.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index f33a5cb5a9..a85b00729b 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -289,7 +289,13 @@ def cubic_decay_step(group, state, grad): if debug: moving_grad_norm = (moving_grad ** 2).mean(dim=(1,2)).sqrt() - cubic_alpha = clip_alpha(moving_grad, prod3, alpha=-(1-beta1)*cubic_decay_proportion) + rank = min(grad.shape[1], grad.shape[2]) + cubic_decay_scale = rank ** -0.25 # for large ranks we tend to get more energy concentrated in + # a smaller proportion of singular values, so the amount of decay would be more than than + # the minimum limit, this corrects for this. this formula is heuristic based on observed + # trends, not exact. + + cubic_alpha = clip_alpha(moving_grad, prod3, alpha=-(1-beta1)*cubic_decay_scale) # cubic_alpha shape: (batch_size, 1, 1) moving_grad.add_(prod3 * cubic_alpha) @@ -449,7 +455,6 @@ def __init__( params, lr=1.2e-02, beta1=0.99, - cubic_decay_proportion=0.5, beta2=0.98, eps=1.0e-08, scale_limits=(0.03, 0.15), @@ -462,7 +467,6 @@ def __init__( defaults = dict( lr=lr, beta1=beta1, - cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, scale_limits=scale_limits, From 7f72684669f6b71f76e7a1d197602964e1b23db8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 10:47:42 +0800 Subject: [PATCH 1134/1191] Remove cubic_decay_proportion arg from train.py. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 3 --- egs/librispeech/ASR/zapformer/train.py | 1 - 2 files changed, 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index a85b00729b..a5250166f4 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -248,9 +248,6 @@ def cubic_decay_step(group, state, grad): beta2_ceil = step / (step + 1) beta2 = min(group["beta2"], beta2_ceil) - - cubic_decay_proportion = group["cubic_decay_proportion"] - orig_shape = grad.shape batch_size = orig_shape[0] rows, cols = matrix_shape(orig_shape[1:]) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index fdcf7cba11..a6551a587c 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,6 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), lr=params.base_lr, - cubic_decay_proportion=0.5, beta1=0.99, ) From ff5fc52c2b32f0d045ebd611c07dae73e597cf0b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 11:11:30 +0800 Subject: [PATCH 1135/1191] Do not scale up step by more than one. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index a5250166f4..8f92e04f23 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -275,14 +275,13 @@ def cubic_decay_step(group, state, grad): # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. - moving_grad.add_(norm_grad_precon) + moving_grad.add_(norm_grad_precon, alpha=(1-beta1)) # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) - - debug = (step % 40 == 0) + debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: moving_grad_norm = (moving_grad ** 2).mean(dim=(1,2)).sqrt() @@ -305,7 +304,7 @@ def cubic_decay_step(group, state, grad): nesterov = True if nesterov: - delta = beta1 * delta + norm_grad # not in-place. + delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) @@ -323,7 +322,12 @@ def cubic_decay_step(group, state, grad): # we ignore nesterov modification for purposes of this formula, it should make little difference anyway # if beta1 is close to 1. - delta = delta * (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) + + scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)).clamp(max=1.0) + if debug: + logging.info(f"shape={prod3.shape}, scale={scale.flatten()}") + + delta = delta * scale ans = -lr * delta From cbb2acf382864124b854a0eeeeed7c85b644f614 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 11:39:42 +0800 Subject: [PATCH 1136/1191] Do sqrt() on the scale that normalizes the step size. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 8f92e04f23..7a8fd3dd70 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -323,7 +323,11 @@ def cubic_decay_step(group, state, grad): # if beta1 is close to 1. - scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)).clamp(max=1.0) + # doing the extra sqrt on the scale means we, in effect, half-normalize the magnitude. + # we can, I think come up with an argument that it's similar to using a different value of beta. + # (argument would require independence of grads on different steps.) + scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)).sqrt() + if debug: logging.info(f"shape={prod3.shape}, scale={scale.flatten()}") From 3f3d732e6647a1fcc505289c08c069b548cc80f3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 13:51:25 +0800 Subject: [PATCH 1137/1191] Have cubic_alpha be computed by quadratic formula for exact decay amount. --- .../ASR/zapformer/batched_rubik.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 7a8fd3dd70..8711010752 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -161,20 +161,28 @@ def scaled_three_way_product(x): x = x * (x_meansq * max(rows, cols)) ** (-1/3) return three_way_product(x) -def clip_alpha(x: Tensor, y: Tensor, alpha: float) -> Tensor: +def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: """ - In a situation where you plan to do: - x.add_(y, alpha=alpha) - returns a possibly-modified value of alpha that - but modified to prevent divergence on x (may use an alpha closer zero if necessary) + Solve the equation: ||x + alpha y||_2^2 == ||beta x||_2^2 + + x.x + 2 alpha y.x + alpha^2 y.y = beta^2 x.x + alpha^2 y.y + 2 alpha x.y + (1-beta^2) x.x = 0 + (a,b,c) = (y.y, 2 alpha x.y, x.x) + alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. + # treat the thing inside the sqrt as zero if + # negative, this + # factoring out 2 from the top and bottom we get: + so alpha = (-x.y + sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y + ... we treat the thing inside the sqrt as zero if it is negative, + which gives us the closest real solution """ - # min_sum_scale the scale beta such that (x + beta y) is minimized; x and - # y each have 2 dimensions. min_sum_scale is expected to be negative. - min_sum_scale = -(x * y).sum(dim=(1, 2), keepdim=True) / ((y ** 2).sum(dim=(1, 2), keepdim=True) + 1.0e-40) - # the safety factor of 0.5 means, don't go all the way to where the dot product of the - # change to x with x would be zero, only go some way to there. - safety_factor = 0.5 - alpha = (safety_factor * min_sum_scale).clamp(min=alpha) + eps = 1.0e-40 + xx = x.square().mean(dim=(1, 2), keepdim=True) + xy = (x * y).mean(dim=(1, 2), keepdim=True) + yy = y.square().mean(dim=(1, 2), keepdim=True) + + alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / (yy + eps) + return alpha @@ -285,13 +293,7 @@ def cubic_decay_step(group, state, grad): if debug: moving_grad_norm = (moving_grad ** 2).mean(dim=(1,2)).sqrt() - rank = min(grad.shape[1], grad.shape[2]) - cubic_decay_scale = rank ** -0.25 # for large ranks we tend to get more energy concentrated in - # a smaller proportion of singular values, so the amount of decay would be more than than - # the minimum limit, this corrects for this. this formula is heuristic based on observed - # trends, not exact. - - cubic_alpha = clip_alpha(moving_grad, prod3, alpha=-(1-beta1)*cubic_decay_scale) + cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: (batch_size, 1, 1) moving_grad.add_(prod3 * cubic_alpha) From 9f362a679772d6158722b522ecc0956643eca4ae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 15:40:50 +0800 Subject: [PATCH 1138/1191] Completely remove the scaling in cubic_decay_step. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 8711010752..b5541ecc46 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -308,7 +308,7 @@ def cubic_decay_step(group, state, grad): if nesterov: delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. - delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + #delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) #if True: # @@ -328,12 +328,11 @@ def cubic_decay_step(group, state, grad): # doing the extra sqrt on the scale means we, in effect, half-normalize the magnitude. # we can, I think come up with an argument that it's similar to using a different value of beta. # (argument would require independence of grads on different steps.) - scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)).sqrt() + #scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)).sqrt() - if debug: - logging.info(f"shape={prod3.shape}, scale={scale.flatten()}") - - delta = delta * scale + #if debug: + # logging.info(f"shape={prod3.shape}, scale={scale.flatten()}") + #delta = delta * scale ans = -lr * delta From 60bd8ec9c8d0db94c2e55acbcae5f5923d4c091f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 19:23:22 +0800 Subject: [PATCH 1139/1191] Propagate the changes from batched_rubik.py to rubik.py. --- .../ASR/zapformer/batched_rubik.py | 19 --- egs/librispeech/ASR/zapformer/rubik.py | 125 ++++++++---------- 2 files changed, 56 insertions(+), 88 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index b5541ecc46..dbac749f56 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -207,25 +207,6 @@ def matrix_shape(shape): assert False, shape -def update_halfnorm_precon(x, row_stats, col_stats, beta2, eps): - """ - half-normalize the rms of x using row-wise and column-wise stats, while - updating the moving-average stats; return the normalized x. - Shapes: - x: (batch_size, rows, cols) -row_stats: (batch_size, rows, 1) -col_stats: (batch_size, 1, cols) - Returns: - normalized x, shape: (batch_size, rows, cols) - """ - row_stats.mul_(beta2).add_(x.abs().mean(dim=2, keepdim=True), alpha=(1 - beta2)) - row_denom = (row_stats.sqrt() + eps) - x = x / row_denom - col_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) - col_denom = (col_stats.sqrt() + eps) - return x / col_denom, row_denom, col_denom - - def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): """ diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 08d44c8b64..e8e3ebf3cd 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -23,6 +23,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch import Tensor from torch.optim import Optimizer @@ -50,20 +51,28 @@ def scaled_three_way_product(x): x = x * (x_meansq * max(rows, cols)) ** (-1/3) return three_way_product(x) -def clip_alpha(x: Tensor, y: Tensor, alpha: float) -> Tensor: +def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: """ - In a situation where you plan to do: - x.add_(y, alpha=alpha) - returns a possibly-modified value of alpha that - but modified to prevent divergence on x (may use an alpha closer zero if necessary) + Solve the equation: ||x + alpha y||_2^2 == ||beta x||_2^2 + + x.x + 2 alpha y.x + alpha^2 y.y = beta^2 x.x + alpha^2 y.y + 2 alpha x.y + (1-beta^2) x.x = 0 + (a,b,c) = (y.y, 2 alpha x.y, x.x) + alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. + # treat the thing inside the sqrt as zero if + # negative, this + # factoring out 2 from the top and bottom we get: + so alpha = (-x.y + sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y + ... we treat the thing inside the sqrt as zero if it is negative, + which gives us the closest real solution """ - # min_sum_scale the scale beta such that (x + beta y) is minimized; x and - # y each have 2 dimensions. min_sum_scale is expected to be negative. - min_sum_scale = -(x * y).sum() / ((y ** 2).sum() + 1.0e-40) - # the safety factor of 0.5 means, don't go all the way to where the dot product of the - # change to x with x would be zero, only go some way to there. - safety_factor = 0.5 - alpha = (safety_factor * min_sum_scale).clamp(min=alpha) + eps = 1.0e-40 + xx = x.square().mean() + xy = (x * y).mean() + yy = y.square().mean() + + alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / (yy + eps) + return alpha @@ -102,10 +111,13 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): Returns: normalized x, shape: (rows, cols) """ - row_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) - x = x / (row_stats.sqrt() + eps) - col_stats.mul_(beta2).add_((x ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) - return x / (col_stats.sqrt() + eps) + row_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) + row_denom = (row_stats + eps) + x = x / row_denom + col_stats.mul_(beta2).add_(x.abs().mean(dim=0, keepdim=True), alpha=(1 - beta2)) + col_denom = (col_stats + eps) + x = x / col_denom + return x, row_denom, col_denom @@ -113,14 +125,11 @@ def cubic_decay_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] - beta1_ceil = 1. - 1. / (10. + 0.1 * step) + beta1_ceil = 1. - 1. / (10. + 0.2 * step) beta1 = min(group["beta1"], beta1_ceil) beta2_ceil = step / (step + 1) beta2 = min(group["beta2"], beta2_ceil) - cubic_decay_proportion = group["cubic_decay_proportion"] - linear_decay_proportion = 1. - cubic_decay_proportion - orig_shape = grad.shape rows, cols = matrix_shape(orig_shape) grad = grad.reshape(rows, cols) @@ -135,65 +144,44 @@ def cubic_decay_step(group, state, grad): row_stats = state["row_stats"] col_stats = state["col_stats"] + # add grad again, like nesterov... just emphasize grad a bit more while also taking into account moving_grad.. + norm_grad, row_denom, col_denom = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + denom_prod = (row_denom * col_denom) + invP = denom_prod.sqrt() # this sqrt is because we only want to do half of it before and half of it after; they already had .sqrt() done to them. + norm_grad_precon = norm_grad * invP # undoes half of the normalization + # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. - moving_grad.add_(grad) - - # We'll scale both before and after the cubic decay; this can be viewed as - # doing the cubic decay in a preconditioned space where the preconditioner - # is 1 / row_col_denom. (The row and column stats will be updated later). - # Looking at this code may give the impression that we are mistakenly - # normalizing "twice". Actually we have an "equilibrium argument" why this - # is actually OK and will give correctly-normalized data. - row_denom = (row_stats.sqrt() + eps) - col_denom = (col_stats.sqrt() + eps) - invP = row_denom * col_denom # inverse preconditioner P - - moving_grad_precon = moving_grad / invP # preconditioned moving_grad - cur_grad_precon = grad / invP # this step's contribution to moving_grad_precon, used for nesterov modification - + moving_grad.add_(norm_grad_precon, alpha=(1-beta1)) + # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. - prod3 = scaled_three_way_product(moving_grad_precon) + prod3 = scaled_three_way_product(moving_grad) + debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) + if debug: + moving_grad_norm = (moving_grad ** 2).mean().sqrt() - cubic_alpha = clip_alpha(moving_grad_precon, prod3, alpha=-(1-beta1)*cubic_decay_proportion) - # cubic_alpha shape: (batch_size, 1, 1). it will be negative. + cubic_alpha = compute_alpha(moving_grad, prod3, beta1) + # cubic_alpha shape: (1, 1) - linear_alpha = -(1-beta1) - cubic_alpha # will be negative. + moving_grad.add_(prod3 * cubic_alpha) - moving_grad_precon.add_(prod3 * cubic_alpha) - moving_grad_precon.mul_(1. + linear_alpha) + if debug: + moving_grad_norm_rel_change = 1. - (moving_grad ** 2).mean().sqrt() / moving_grad_norm + logging.info(f"shape={prod3.shape}, moving_grad_rel_change={moving_grad_norm_rel_change}, vs. target {(1-beta1)}") - # update moving_grad as interpolation between linear decay and cubic decay. - moving_grad[:] = moving_grad_precon * invP + delta = moving_grad / invP # re-add the half of the normalizatin that we removed nesterov = True if nesterov: - moving_grad_precon = moving_grad_precon + cur_grad_precon - - # Now compute "negative_update" which is negative_update_precon multiplied again by the - # preconditioner, this takes us from the preconditioned to the canonical co-ordinates but now treating the quantity as a parameter-update - # rather than as a gradient. it is going to be very close to: - # negative_update = moving_grad_precon / invP - # but we also update the preconditioner. Note: practically speaking we are multiplying - # by the same thing twice, i.e. dividing "grad" twice by invP. - negative_update = normalize_and_update_stats(moving_grad_precon, row_stats, col_stats, beta2, eps) - - # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what - # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. - # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style - # accumulator with beta equal to beta1. - # This should make divergence less likely. - # we ignore nesterov modification for purposes of this formula, it should make little difference anyway - # if beta1 is close to 1. - assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. - negative_update = negative_update * (assumed_scale / ((negative_update ** 2).mean().sqrt() + eps)) + ans = -lr * delta + + return ans.reshape(orig_shape) - ans = -lr * negative_update - return ans.reshape(orig_shape) def scaling_step(group, param, state, grad): @@ -308,8 +296,7 @@ def __init__( self, params, lr=1.2e-02, - beta1=0.995, - cubic_decay_proportion=0.8, + beta1=0.99, beta2=0.98, eps=1.0e-08, scale_limits=(0.03, 0.15), @@ -321,7 +308,6 @@ def __init__( defaults = dict( lr=lr, beta1=beta1, - cubic_decay_proportion=cubic_decay_proportion, beta2=beta2, eps=eps, scale_limits=scale_limits, @@ -416,8 +402,8 @@ def _test_rubik(hidden_dim: int): for _ in range(20) ] - lr = 0.024 - optim = Rubik(m.parameters(), lr=lr, beta1=0.999) + lr = 0.018 + optim = Rubik(m.parameters(), lr=lr, beta1=0.998) num_epochs = 180 @@ -482,6 +468,7 @@ def lr_lambda(current_step): logging.info(f"output_magnitudes = {output_magnitudes}") + def _test_scaled_three_way_product(): x = torch.randn(16, 32) _U, _S, V = torch.linalg.svd(x, full_matrices=False) From a7bae3467d58598e0fb8a15cb18bd7f77636ba88 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 21:59:00 +0800 Subject: [PATCH 1140/1191] Limit norm_grad to -3..3. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index dbac749f56..f088e886a1 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -257,6 +257,8 @@ def cubic_decay_step(group, state, grad): # add grad again, like nesterov... just emphasize grad a bit more while also taking into account moving_grad.. norm_grad, row_denom, col_denom = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + norm_grad.clamp_(min=-3, max=3) + denom_prod = (row_denom * col_denom) invP = denom_prod.sqrt() # this sqrt is because we only want to do half of it before and half of it after; they already had .sqrt() done to them. From 2422703ddd8a141180926955d12efd9cfc169484 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 May 2026 22:33:35 +0800 Subject: [PATCH 1141/1191] Remove clamping and instead fully normalize scale. --- .../ASR/zapformer/batched_rubik.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index f088e886a1..8012464a78 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -257,8 +257,6 @@ def cubic_decay_step(group, state, grad): # add grad again, like nesterov... just emphasize grad a bit more while also taking into account moving_grad.. norm_grad, row_denom, col_denom = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) - norm_grad.clamp_(min=-3, max=3) - denom_prod = (row_denom * col_denom) invP = denom_prod.sqrt() # this sqrt is because we only want to do half of it before and half of it after; they already had .sqrt() done to them. @@ -273,25 +271,19 @@ def cubic_decay_step(group, state, grad): prod3 = scaled_three_way_product(moving_grad) debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) - if debug: - moving_grad_norm = (moving_grad ** 2).mean(dim=(1,2)).sqrt() cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: (batch_size, 1, 1) moving_grad.add_(prod3 * cubic_alpha) - if debug: - moving_grad_norm_rel_change = 1. - (moving_grad ** 2).mean(dim=(1,2)).sqrt() / moving_grad_norm - logging.info(f"shape={prod3.shape}, moving_grad_rel_change={moving_grad_norm_rel_change}, vs. target {(1-beta1)}") - delta = moving_grad / invP # re-add the half of the normalizatin that we removed nesterov = True if nesterov: delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. - #delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) #if True: # @@ -311,11 +303,12 @@ def cubic_decay_step(group, state, grad): # doing the extra sqrt on the scale means we, in effect, half-normalize the magnitude. # we can, I think come up with an argument that it's similar to using a different value of beta. # (argument would require independence of grads on different steps.) - #scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)).sqrt() + + scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) - #if debug: - # logging.info(f"shape={prod3.shape}, scale={scale.flatten()}") - #delta = delta * scale + if debug: + logging.info(f"shape={prod3.shape}, scale={scale.flatten()}") + delta = delta * scale ans = -lr * delta From aa486eeffceb8f91eb30d6a29fbc520b6078e1e2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 19 May 2026 11:00:53 +0800 Subject: [PATCH 1142/1191] Update stats twice, once after time averaging. --- .../ASR/zapformer/batched_rubik.py | 57 +++++++++++++------ egs/librispeech/ASR/zapformer/rubik.py | 55 +++++++++++++----- 2 files changed, 79 insertions(+), 33 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 8012464a78..0cc9d030bb 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -208,7 +208,7 @@ def matrix_shape(shape): -def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): +def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): """ Normalize the rms of x using row-wise and column-wise stats, while updating the moving-average stats; return the normalized x. @@ -224,8 +224,31 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x = x / row_denom col_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) col_denom = (col_stats + eps) + x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() + x = x / col_denom + return x, x_half_norm + + + +def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): + """ + Normalize the rms of x using row-wise and column-wise stats, while + updating the moving-average stats; return the normalized x. + Shapes: + x: (batch_size, rows, cols) +row_stats: (batch_size, rows, 1) +col_stats: (batch_size, 1, cols) + Returns: + normalized x, shape: (batch_size, rows, cols) + """ + row_stats.mul_(beta2).add_((x ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) + row_denom = (row_stats.sqrt() + eps) + x = x / row_denom + col_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + col_denom = (col_stats.sqrt() + eps) x = x / col_denom - return x, row_denom, col_denom + return x + def cubic_decay_step(group, state, grad): @@ -253,14 +276,8 @@ def cubic_decay_step(group, state, grad): row_stats = state["row_stats"] col_stats = state["col_stats"] - - # add grad again, like nesterov... just emphasize grad a bit more while also taking into account moving_grad.. - norm_grad, row_denom, col_denom = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) - - denom_prod = (row_denom * col_denom) - invP = denom_prod.sqrt() # this sqrt is because we only want to do half of it before and half of it after; they already had .sqrt() done to them. - - norm_grad_precon = norm_grad * invP # undoes half of the normalization + # we half update the stats here, half update them later. + norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, 0.5*(1+beta2), eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -270,21 +287,25 @@ def cubic_decay_step(group, state, grad): # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) - debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: (batch_size, 1, 1) moving_grad.add_(prod3 * cubic_alpha) - delta = moving_grad / invP # re-add the half of the normalizatin that we removed + # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" + # will have a smaller variance than the grad itself because of being a mean over independent elements. + # we rescale before getting the stats, to have the same variance as if it were the grad. + # The actual variance of moving_grad also depends on the variance of the original grads; this is just + # a scalar component in the variance to accountn for averaging-over-time effects. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + + delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, 0.5*(1+beta2), eps) nesterov = True if nesterov: delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. - delta_assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - #if True: # #if step < 5 or (step < 500 and step % 10 == 0): @@ -304,11 +325,11 @@ def cubic_decay_step(group, state, grad): # we can, I think come up with an argument that it's similar to using a different value of beta. # (argument would require independence of grads on different steps.) - scale = (delta_assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) - + debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: - logging.info(f"shape={prod3.shape}, scale={scale.flatten()}") - delta = delta * scale + scale = (assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) + logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied]") + #delta = delta * scale ans = -lr * delta diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index e8e3ebf3cd..4846195ab4 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -100,7 +100,7 @@ def matrix_shape(shape): assert False, shape -def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): +def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): """ Normalize the rms of x using row-wise and column-wise stats, while updating the moving-average stats; return the normalized x. @@ -116,8 +116,31 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x = x / row_denom col_stats.mul_(beta2).add_(x.abs().mean(dim=0, keepdim=True), alpha=(1 - beta2)) col_denom = (col_stats + eps) + x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() x = x / col_denom - return x, row_denom, col_denom + return x, x_half_norm + + + +def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): + """ + Normalize the rms of x using row-wise and column-wise stats, while + updating the moving-average stats; return the normalized x. + Shapes: + x: (batch_size, rows, cols) +row_stats: (batch_size, rows, 1) +col_stats: (batch_size, 1, cols) + Returns: + normalized x, shape: (batch_size, rows, cols) + """ + row_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + row_denom = (row_stats.sqrt() + eps) + x = x / row_denom + col_stats.mul_(beta2).add_((x ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) + col_denom = (col_stats.sqrt() + eps) + x = x / col_denom + return x + @@ -144,11 +167,8 @@ def cubic_decay_step(group, state, grad): row_stats = state["row_stats"] col_stats = state["col_stats"] - # add grad again, like nesterov... just emphasize grad a bit more while also taking into account moving_grad.. - norm_grad, row_denom, col_denom = normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) - denom_prod = (row_denom * col_denom) - invP = denom_prod.sqrt() # this sqrt is because we only want to do half of it before and half of it after; they already had .sqrt() done to them. - norm_grad_precon = norm_grad * invP # undoes half of the normalization + # we half update the stats here, half update them later. + norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, 0.5*(1+beta2), eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -158,25 +178,30 @@ def cubic_decay_step(group, state, grad): # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) - debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) - if debug: - moving_grad_norm = (moving_grad ** 2).mean().sqrt() cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: (1, 1) moving_grad.add_(prod3 * cubic_alpha) - if debug: - moving_grad_norm_rel_change = 1. - (moving_grad ** 2).mean().sqrt() / moving_grad_norm - logging.info(f"shape={prod3.shape}, moving_grad_rel_change={moving_grad_norm_rel_change}, vs. target {(1-beta1)}") - - delta = moving_grad / invP # re-add the half of the normalizatin that we removed + # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" + # will have a smaller variance than the grad itself because of being a mean over independent elements. + # we rescale before getting the stats, to have the same variance as if it were the grad. + # The actual variance of moving_grad also depends on the variance of the original grads; this is just + # a scalar component in the variance to accountn for averaging-over-time effects. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, 0.5*(1+beta2), eps) nesterov = True if nesterov: delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. + debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) + if debug: + scale = (assumed_scale / ((delta ** 2).mean().sqrt() + eps)) + logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied]") + #delta = delta * scale + ans = -lr * delta return ans.reshape(orig_shape) From 5e18b1e97b6f352f360eae0c3fdd80f46fa3cc55 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 19 May 2026 11:20:03 +0800 Subject: [PATCH 1143/1191] Remove unnecessary dist reduce --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 0cc9d030bb..79c9ed9d97 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -23,7 +23,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch -import torch.distributed as dist from torch import Tensor from torch.optim import Optimizer @@ -504,9 +503,6 @@ def step(self, closure=None): for p, state in batches: grad = p.grad - if dist.is_initialized(): - dist.all_reduce(grad, op=dist.ReduceOp.AVG) - try: cur_step = state["step"] except KeyError: From 8e8282434fce5d2bb6bbfc4840ddd25f3ce08d22 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 19 May 2026 13:10:52 +0800 Subject: [PATCH 1144/1191] Introduce beta2b_scale=0.1 to make stats dominated by grad not moving_grad. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 7 +++++-- egs/librispeech/ASR/zapformer/rubik.py | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 79c9ed9d97..9df1cfefb3 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -276,7 +276,7 @@ def cubic_decay_step(group, state, grad): col_stats = state["col_stats"] # we half update the stats here, half update them later. - norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, 0.5*(1+beta2), eps) + norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -299,7 +299,10 @@ def cubic_decay_step(group, state, grad): # a scalar component in the variance to accountn for averaging-over-time effects. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, 0.5*(1+beta2), eps) + # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will # make the stats update more dominated by grad rather than moving_grad. + beta2b_scale = 0.1 + beta2b = beta2b_scale * beta2 + (1 - beta2b_scale) + delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, beta2b, eps) nesterov = True if nesterov: diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 4846195ab4..1b903bde84 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -168,7 +168,7 @@ def cubic_decay_step(group, state, grad): col_stats = state["col_stats"] # we half update the stats here, half update them later. - norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, 0.5*(1+beta2), eps) + norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -190,7 +190,11 @@ def cubic_decay_step(group, state, grad): # The actual variance of moving_grad also depends on the variance of the original grads; this is just # a scalar component in the variance to accountn for averaging-over-time effects. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, 0.5*(1+beta2), eps) + + # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will # make the stats update more dominated by grad rather than moving_grad. + beta2b_scale = 0.1 + beta2b = beta2b_scale * beta2 + (1 - beta2b_scale) + delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, beta2b, eps) nesterov = True if nesterov: From c2fc785dd46cc152db3611cf870d49f8d3e3b44a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 20 May 2026 14:14:36 +0800 Subject: [PATCH 1145/1191] take changes from nanochat setup for memory usage. --- .../ASR/zapformer/batched_rubik.py | 92 +++++++++---------- 1 file changed, 45 insertions(+), 47 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 9df1cfefb3..4f1f8c0b8f 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -23,6 +23,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch import Tensor from torch.optim import Optimizer @@ -83,12 +84,27 @@ def batched_params(self, param_list): batches[key].append(p) + old_batches = batches.values() # a list of lists + # Now split up any batches that are too large. + batches = [ ] + for b in old_batches: + num_tensors = len(b) + num_bytes = num_tensors * b[0].nbytes # total bytes in group of tensors + max_bytes = 2 ** 30 # 1024**3 == one gigabyte + num_groups = min(num_tensors, (num_bytes + max_bytes - 1) // max_bytes) + group_size = (num_tensors + num_groups - 1) // num_groups + tot = 0 + for g in range(num_groups): + batches.append(b[g*group_size:(g+1)*group_size]) + tot += len(batches[-1]) + assert tot == num_tensors + # tuples will contain tuples of (stacked_param, state), # one for each batch in `batches`. tuples = [] - for batch in batches.values(): + for batch in batches: p = batch[0] # we arbitrarily store the state in the # state corresponding to the 1st parameter in the @@ -103,50 +119,22 @@ def batched_params(self, param_list): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state), batch) in zip(tuples, batches.values()): + for ((stacked_params, _state), batch) in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) -def compute_prod3(x): - assert x.ndim >= 2 +def three_way_product(x): + """ returns the 3-way matrix product x @ x.t() @ x """ if x.shape[-2] <= x.shape[-1]: - x2 = torch.matmul(x, x.transpose(-2, -1)) + x2 = torch.matmul(x, x.mT) return torch.matmul(x2, x) else: - x2 = torch.matmul(x.transpose(-2, -1), x) + x2 = torch.matmul(x.mT, x) return torch.matmul(x, x2) -def _three_way_product_chunk(x_chunk): - """Core computation: x_chunk @ x_chunk.T @ x_chunk for a single chunk.""" - if x_chunk.shape[-2] <= x_chunk.shape[-1]: - x2 = torch.matmul(x_chunk, x_chunk.transpose(-2, -1)) - return torch.matmul(x2, x_chunk) - else: - x2 = torch.matmul(x_chunk.transpose(-2, -1), x_chunk) - return torch.matmul(x_chunk, x2) - - -def three_way_product(x, chunk_size=32): - """ returns the 3-way matrix product x @ x.t() @ x - - Processes the batch dimension in chunks to reduce peak GPU memory usage. - The intermediate x @ x.T has shape (batch, rows, rows) which can be very - large; chunking keeps peak memory proportional to chunk_size instead of batch. - """ - assert x.ndim >= 2 - batch = x.shape[0] - if batch <= chunk_size: - return _three_way_product_chunk(x) - results = [] - for start in range(0, batch, chunk_size): - end = min(start + chunk_size, batch) - results.append(_three_way_product_chunk(x[start:end])) - return torch.cat(results, dim=0) - - def scaled_three_way_product(x): """ Returns alpha * (x @ x.t() @ x), @@ -167,16 +155,16 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: x.x + 2 alpha y.x + alpha^2 y.y = beta^2 x.x alpha^2 y.y + 2 alpha x.y + (1-beta^2) x.x = 0 (a,b,c) = (y.y, 2 alpha x.y, x.x) - alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. + alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. # treat the thing inside the sqrt as zero if - # negative, this + # negative, this # factoring out 2 from the top and bottom we get: so alpha = (-x.y + sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y ... we treat the thing inside the sqrt as zero if it is negative, which gives us the closest real solution """ eps = 1.0e-40 - xx = x.square().mean(dim=(1, 2), keepdim=True) + xx = x.square().mean(dim=(1, 2), keepdim=True) xy = (x * y).mean(dim=(1, 2), keepdim=True) yy = y.square().mean(dim=(1, 2), keepdim=True) @@ -226,8 +214,8 @@ def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() x = x / col_denom return x, x_half_norm - - + + def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): """ @@ -240,10 +228,11 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): Returns: normalized x, shape: (batch_size, rows, cols) """ - row_stats.mul_(beta2).add_((x ** 2).mean(dim=2, keepdim=True), alpha=(1 - beta2)) + # use squared norm to save memory + row_stats.mul_(beta2).add_(x.square().mean(dim=2, keepdim=True), alpha=(1 - beta2)) row_denom = (row_stats.sqrt() + eps) x = x / row_denom - col_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + col_stats.mul_(beta2).add_(x.square().mean(dim=1, keepdim=True), alpha=(1 - beta2)) col_denom = (col_stats.sqrt() + eps) x = x / col_denom return x @@ -299,14 +288,16 @@ def cubic_decay_step(group, state, grad): # a scalar component in the variance to accountn for averaging-over-time effects. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will # make the stats update more dominated by grad rather than moving_grad. + + # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will + # make the stats update more dominated by grad rather than moving_grad. beta2b_scale = 0.1 beta2b = beta2b_scale * beta2 + (1 - beta2b_scale) delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, beta2b, eps) nesterov = True if nesterov: - delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. + delta.lerp_(norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. #if True: # @@ -331,11 +322,10 @@ def cubic_decay_step(group, state, grad): if debug: scale = (assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied]") - #delta = delta * scale - ans = -lr * delta + delta.mul_(-lr) - return ans.reshape(orig_shape) + return delta.reshape(orig_shape) def scaling_step(group, param, state, grad): @@ -454,6 +444,8 @@ class BatchedRubik(BatchedOptimizer): scale_default: A constant that dictates the RMS value to which weight magnitudes decay. scalar_lr_scale: A scaling factor on the learning rate, that we use to update scalar tensors. eps: A general-purpose epsilon to prevent division by zero +grad_aggregation: if None, no grad aggregation is done here (assume it is done in DDP if relevant); + set it to torch.distributed.ReduceOp.AVG or torch.distributed.ReduceOp.SUM to have it done by this class. """ def __init__( self, @@ -467,8 +459,9 @@ def __init__( adam_beta1=0.98, adam_beta2=0.98, scale_momentum=0.95, + grad_aggregation=None, ): - + self.grad_aggregation = grad_aggregation defaults = dict( lr=lr, beta1=beta1, @@ -506,6 +499,11 @@ def step(self, closure=None): for p, state in batches: grad = p.grad + if self.grad_aggregation is not None and dist.is_initialized(): + # sync grads. + dist.all_reduce(grad, op=self.grad_aggregation) + + try: cur_step = state["step"] except KeyError: From 26bdc2a9618da15c6fc1c47c401218810c318b89 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 20 May 2026 14:16:11 +0800 Subject: [PATCH 1146/1191] Remove beta2b, use just beta. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 4f1f8c0b8f..4594056d75 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -288,12 +288,7 @@ def cubic_decay_step(group, state, grad): # a scalar component in the variance to accountn for averaging-over-time effects. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - - # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will - # make the stats update more dominated by grad rather than moving_grad. - beta2b_scale = 0.1 - beta2b = beta2b_scale * beta2 + (1 - beta2b_scale) - delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, beta2b, eps) + delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, beta2, eps) nesterov = True if nesterov: From fd5a11670f83a09cbfbf75fbddb4f1eb96e9d811 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 20 May 2026 21:48:21 +0800 Subject: [PATCH 1147/1191] Extra printout, alpha_ratio --- egs/librispeech/ASR/zapformer/batched_rubik.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 4594056d75..190c201d90 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -315,8 +315,9 @@ def cubic_decay_step(group, state, grad): debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: + cubic_alpha_ratio = -cubic_alpha / (1-beta1) scale = (assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) - logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied]") + logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied], alpha_ratio={cubic_alpha_ratio.flatten()}") delta.mul_(-lr) From 9de596b03f3609d610034f93e16ff6cbb2e656f8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 20 May 2026 22:36:31 +0800 Subject: [PATCH 1148/1191] Remove warmup for second beta2. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 190c201d90..a4f0df3168 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -288,7 +288,9 @@ def cubic_decay_step(group, state, grad): # a scalar component in the variance to accountn for averaging-over-time effects. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, beta2, eps) + # use the original beta2, not the reduced one, for this step. + delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, + group["beta2"], eps) nesterov = True if nesterov: @@ -317,7 +319,7 @@ def cubic_decay_step(group, state, grad): if debug: cubic_alpha_ratio = -cubic_alpha / (1-beta1) scale = (assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) - logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied], alpha_ratio={cubic_alpha_ratio.flatten()}") + logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied], alpha_ratio={cubic_alpha_ratio.flatten()}, delta-max={delta.abs().max(dim=1)[0].max(dim=1)[0]}") delta.mul_(-lr) From 76d93f5e1e55069d6b14375143af357a1fe7abc5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 20 May 2026 23:15:46 +0800 Subject: [PATCH 1149/1191] Clamp to -4..4. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index a4f0df3168..089b3f6ae1 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -296,6 +296,10 @@ def cubic_decay_step(group, state, grad): if nesterov: delta.lerp_(norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. + + # try to prevent divergence at the start. + delta.clamp_(min=-4, max=4) + #if True: # #if step < 5 or (step < 500 and step % 10 == 0): @@ -315,6 +319,7 @@ def cubic_decay_step(group, state, grad): # we can, I think come up with an argument that it's similar to using a different value of beta. # (argument would require independence of grads on different steps.) + debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: cubic_alpha_ratio = -cubic_alpha / (1-beta1) From b2908cdd235a8d404edd15a39f8b9413497bbb73 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 13:06:12 +0800 Subject: [PATCH 1150/1191] Use invP as preconditioning for compute_alpha invocation. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 089b3f6ae1..a2632ff7ba 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -211,9 +211,12 @@ def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x = x / row_denom col_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) col_denom = (col_stats + eps) - x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() + row_denom_sqrt = row_denom.sqrt() + col_denom_sqrt = col_denom.sqrt() + x_half_norm = (x * row_denom_sqrt) / col_denom_sqrt x = x / col_denom - return x, x_half_norm + invP = row_denom * col_denom + return x, x_half_norm, invP @@ -265,7 +268,7 @@ def cubic_decay_step(group, state, grad): col_stats = state["col_stats"] # we half update the stats here, half update them later. - norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + norm_grad, norm_grad_precon, invP = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -275,8 +278,10 @@ def cubic_decay_step(group, state, grad): # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) - - cubic_alpha = compute_alpha(moving_grad, prod3, beta1) + # dividing the following by invP means we are using 1 / invP as a scale for computing + # norms, as if we were to compute the norm of delta ~= moving_grad / invP after doing + # moving_grad.add_(prod3 * cubic_alpha). + cubic_alpha = compute_alpha(moving_grad / invP, prod3 / invP, beta1) # cubic_alpha shape: (batch_size, 1, 1) moving_grad.add_(prod3 * cubic_alpha) From 10ac67a3131dc24f987e75702d762a470112f825 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 12:08:20 +0800 Subject: [PATCH 1151/1191] Increase weight_rms in OrthogonalLinear to slow down learning of projections; remove projection_overlap loss. --- egs/librispeech/ASR/zapformer/zapformer.py | 18 +++++++++++------- .../ASR/zapformer/zapformer_modules.py | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index f6bfe50696..f6fdf330aa 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -285,22 +285,26 @@ def forward( x = self.out_norm(x) - if self.training: - # all of our losses and aux losses are proportional to the number of frames of data, so - # we multiply by that factor. - x = with_loss(x, aux_loss_scale * x.shape[0] * x.shape[1] * self.compute_projection_overlap()) + # disable the projection-overlap loss + #if self.training: + # # all of our losses and aux losses are proportional to the number of frames of data, so + # # we multiply by that factor. + # x = with_loss(x, aux_loss_scale * x.shape[0] * x.shape[1] * self.compute_projection_overlap()) return x, x_lens def compute_projection_overlap(self, verbose: bool = False): - # This computes a quantity that we'll use as an auxiliary loss. - # It ensures that the projections from more-subsampled sequences "contain" enough of the + # This is currently just used for some diagnostics. + + # It also computes an auxiliary loss (currently unused) that + # ensures that the projections from more-subsampled sequences "contain" enough of the # projections from the less-subsampled sequences-- specifically the direction where # all the less-subsampled projections co-vary in the same way, e.g. if there are # two frames, that the two frames are identical. - min_overlap = 0.7 # we can tune this + min_overlap = 0.6 # we can tune this. CAUTION: I turned off this aux loss by commenting + # it out in forward(), tot_loss = 0.0 # between pairs of encoders diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py index f7845ea94c..adad0fc5bc 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_modules.py +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -528,7 +528,7 @@ class OrthogonalLinear(nn.Linear): def __init__(self, in_channels: int, out_channels: int, - weight_rms: float = 0.2, + weight_rms: float = 0.3, bias: bool = True, penalty_scale: float = 20.0, ): From 44ddf89c3ea2808f8da24dfab3afdb847731acf4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 13:39:52 +0800 Subject: [PATCH 1152/1191] Go back to using beta2b, like reverse of 3178, --- egs/librispeech/ASR/zapformer/batched_rubik.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index a2632ff7ba..c7d6edfb55 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -293,9 +293,12 @@ def cubic_decay_step(group, state, grad): # a scalar component in the variance to accountn for averaging-over-time effects. assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - # use the original beta2, not the reduced one, for this step. + # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will + # make the stats update more dominated by grad rather than moving_grad. + beta2b_scale = 0.1 + beta2b = beta2b_scale * beta2 + (1 - beta2b_scale) delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, - group["beta2"], eps) + beta2b, eps) nesterov = True if nesterov: From feb7711d3dda04b4f43468fc048edb5a67a09cb8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 13:56:14 +0800 Subject: [PATCH 1153/1191] Do conventional beta1 decay for first 200 steps. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index a2632ff7ba..aca35a7b31 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -286,6 +286,10 @@ def cubic_decay_step(group, state, grad): moving_grad.add_(prod3 * cubic_alpha) + if step < 200: + # to avoid divergence at the start, do normal decay for the first 200 steps. + moving_grad.mul_(beta1) + # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" # will have a smaller variance than the grad itself because of being a mean over independent elements. # we rescale before getting the stats, to have the same variance as if it were the grad. From 9a348d63dda6b512557f3052a9549c906cc8e9f8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 14:23:09 +0800 Subject: [PATCH 1154/1191] Clamp zapformer layer output to -5..5 to prevent divergence early on with fp16. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ---- egs/librispeech/ASR/zapformer/zapformer.py | 12 ++++++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index aca35a7b31..a2632ff7ba 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -286,10 +286,6 @@ def cubic_decay_step(group, state, grad): moving_grad.add_(prod3 * cubic_alpha) - if step < 200: - # to avoid divergence at the start, do normal decay for the first 200 steps. - moving_grad.mul_(beta1) - # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" # will have a smaller variance than the grad itself because of being a mean over independent elements. # we rescale before getting the stats, to have the same variance as if it were the grad. diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index f6fdf330aa..f3b7679999 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -680,7 +680,7 @@ def forward( key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) - src = src + self.conv_module(3. * src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) + src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale) src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask) @@ -693,6 +693,8 @@ def forward( src = self.norm(src, src_key_padding_mask) + src = src.clamp(min=-5, max=5) + return src def streaming_forward( @@ -782,6 +784,8 @@ def streaming_forward( cached_len=cached_norm_len, ) + src = src.clamp(min=-5, max=5) + return ( src, cached_key, @@ -2036,7 +2040,7 @@ def __init__( bottleneck_dim, channels, activation="SwashR", - initial_scale=0.05, + initial_scale=0.2, ) def forward( @@ -2056,8 +2060,8 @@ def forward( Returns: Tensor: Output tensor (#time, batch, channels). """ - - x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) + input_scale = 3. + x = self.in_proj(x * input_scale) # (time, batch, 3*bottleneck_dim) x, y = x.chunk(2, dim=2) y = self.sigmoid(y) From aac8998b21894a18f28488e07b290ced5f54568b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 14:25:18 +0800 Subject: [PATCH 1155/1191] Finish changing input_scale of conv module to be inside conv_module. --- egs/librispeech/ASR/zapformer/zapformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index f3b7679999..5e4587ed84 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -763,7 +763,7 @@ def streaming_forward( src = src + self_attn_out src_conv, cached_conv, cached_conv_wm_sum, cached_conv_wm_num_frames = self.conv_module.streaming_forward( - 3.0 * src, + src, cached_conv=cached_conv, cached_wm_sum=cached_conv_wm_sum, cached_wm_num_frames=cached_conv_wm_num_frames, @@ -2040,7 +2040,7 @@ def __init__( bottleneck_dim, channels, activation="SwashR", - initial_scale=0.2, + initial_scale=0.05, ) def forward( @@ -2122,7 +2122,8 @@ def streaming_forward( - Updated cached_wm_sum (1, batch, channels) - Updated cached_wm_num_frames (batch,) """ - x = self.in_proj(x) # (time, batch, 3*bottleneck_dim) + input_scale = 3. + x = self.in_proj(x * input_scale) # (time, batch, 3*bottleneck_dim) x, y = x.chunk(2, dim=2) y = self.sigmoid(y) From d4d81a07749bba224dc0c30130ce0baa6916e069 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 12:08:20 +0800 Subject: [PATCH 1156/1191] Propagate recent changes to batched_rubik.py to rubik.py --- egs/librispeech/ASR/zapformer/rubik.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 1b903bde84..9336373ecf 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -116,9 +116,12 @@ def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x = x / row_denom col_stats.mul_(beta2).add_(x.abs().mean(dim=0, keepdim=True), alpha=(1 - beta2)) col_denom = (col_stats + eps) - x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() + row_denom_sqrt = row_denom.sqrt() + col_denom_sqrt = col_denom.sqrt() + x_half_norm = (x * row_denom_sqrt) / col_denom_sqrt x = x / col_denom - return x, x_half_norm + invP = row_denom * col_denom + return x, x_half_norm, invP @@ -168,7 +171,7 @@ def cubic_decay_step(group, state, grad): col_stats = state["col_stats"] # we half update the stats here, half update them later. - norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + norm_grad, norm_grad_precon, invP = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -179,8 +182,11 @@ def cubic_decay_step(group, state, grad): prod3 = scaled_three_way_product(moving_grad) - cubic_alpha = compute_alpha(moving_grad, prod3, beta1) - # cubic_alpha shape: (1, 1) + # dividing the following by invP means we are using 1 / invP as a scale for computing + # norms, as if we were to compute the norm of delta ~= moving_grad / invP after doing + # moving_grad.add_(prod3 * cubic_alpha). + cubic_alpha = compute_alpha(moving_grad / invP, prod3 / invP, beta1) + # cubic_alpha shape: scalar moving_grad.add_(prod3 * cubic_alpha) @@ -200,6 +206,9 @@ def cubic_decay_step(group, state, grad): if nesterov: delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. + # try to prevent divergence at the start. + delta.clamp_(min=-4, max=4) + debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: scale = (assumed_scale / ((delta ** 2).mean().sqrt() + eps)) From b8347d3d45b23a4c35434eaea4f8578d5fd6134d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 20:14:01 +0800 Subject: [PATCH 1157/1191] Introduce safety_factor=0.5 in alpha computation. --- .../ASR/zapformer/batched_rubik.py | 20 ++++++++++---- egs/librispeech/ASR/zapformer/rubik.py | 27 +++++++++++++------ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index a2632ff7ba..0549fcc66e 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -89,7 +89,7 @@ def batched_params(self, param_list): batches = [ ] for b in old_batches: num_tensors = len(b) - num_bytes = num_tensors * b[0].nbytes # total bytes in group of tensors + num_bytes = num_tensors * b[0].numel() * 4 # total bytes in group of tensors, assuming float max_bytes = 2 ** 30 # 1024**3 == one gigabyte num_groups = min(num_tensors, (num_bytes + max_bytes - 1) // max_bytes) group_size = (num_tensors + num_groups - 1) // num_groups @@ -128,10 +128,10 @@ def batched_params(self, param_list): def three_way_product(x): """ returns the 3-way matrix product x @ x.t() @ x """ if x.shape[-2] <= x.shape[-1]: - x2 = torch.matmul(x, x.mT) + x2 = torch.matmul(x, x.transpose(-2, -1)) return torch.matmul(x2, x) else: - x2 = torch.matmul(x.mT, x) + x2 = torch.matmul(x.transpose(-2, -1), x) return torch.matmul(x, x2) @@ -167,10 +167,20 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: xx = x.square().mean(dim=(1, 2), keepdim=True) xy = (x * y).mean(dim=(1, 2), keepdim=True) yy = y.square().mean(dim=(1, 2), keepdim=True) + yyeps = yy + eps - alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / (yy + eps) + # this alpha is the value that solves exactly for the requested difference in norm. + # this will be negative. + alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / yyeps - return alpha + # min_sum_scale is the value of alpha that would minimize the norm of a + alpha y. + min_sum_scale = -xy / yyeps + # safety_factor = 0.5 means we are only willing to go halfway to that value that minimizes the norm, + # to avoid change of eigenvalue sign / overshoot, which can ultimately lead to certain + # parameter eigenvalues getting too large. + safety_factor = 0.5 + + return torch.maximum(safety_factor * min_sum_scale, alpha) # return the closet to zero of these two formulae. def matrix_shape(shape): diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 1b903bde84..e610a1abd4 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -58,9 +58,9 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: x.x + 2 alpha y.x + alpha^2 y.y = beta^2 x.x alpha^2 y.y + 2 alpha x.y + (1-beta^2) x.x = 0 (a,b,c) = (y.y, 2 alpha x.y, x.x) - alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. + alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. # treat the thing inside the sqrt as zero if - # negative, this + # negative, this # factoring out 2 from the top and bottom we get: so alpha = (-x.y + sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y ... we treat the thing inside the sqrt as zero if it is negative, @@ -70,10 +70,21 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: xx = x.square().mean() xy = (x * y).mean() yy = y.square().mean() + yyeps = yy + eps - alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / (yy + eps) + # this alpha is the value that solves exactly for the requested difference in norm. + # this will be negative. + alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / yyeps + + # min_sum_scale is the value of alpha that would minimize the norm of a + alpha y. + min_sum_scale = -xy / yyeps + # safety_factor = 0.5 means we are only willing to go halfway to that value that minimizes the norm, + # to avoid change of eigenvalue sign / overshoot, which can ultimately lead to certain + # parameter eigenvalues getting too large. + safety_factor = 0.5 + + return torch.maximum(safety_factor * min_sum_scale, alpha) # return the closet to zero of these two formulae. - return alpha @@ -119,8 +130,8 @@ def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() x = x / col_denom return x, x_half_norm - - + + def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): """ @@ -173,7 +184,7 @@ def cubic_decay_step(group, state, grad): # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. moving_grad.add_(norm_grad_precon, alpha=(1-beta1)) - + # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) @@ -207,7 +218,7 @@ def cubic_decay_step(group, state, grad): #delta = delta * scale ans = -lr * delta - + return ans.reshape(orig_shape) From f7148d09726133c274085be8e9173adf84506159 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 21:10:52 +0800 Subject: [PATCH 1158/1191] fix merge issue in rubik.py --- egs/librispeech/ASR/zapformer/rubik.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index ff43fb0efd..85ee13aabe 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -135,8 +135,6 @@ def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): return x, x_half_norm, invP ->>>>>>> deterministic_invertible3187conv - def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): """ Normalize the rms of x using row-wise and column-wise stats, while From 69da97a16c5d9ada0ec024cd8f604c82a8d1146b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 21 May 2026 22:38:00 +0800 Subject: [PATCH 1159/1191] remove clamp(-4,4) from rubik. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 -- egs/librispeech/ASR/zapformer/rubik.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 78fa277ad4..cb7fc81c99 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -315,8 +315,6 @@ def cubic_decay_step(group, state, grad): delta.lerp_(norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. - # try to prevent divergence at the start. - delta.clamp_(min=-4, max=4) #if True: # diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 85ee13aabe..bea26519d1 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -216,9 +216,6 @@ def cubic_decay_step(group, state, grad): if nesterov: delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. - # try to prevent divergence at the start. - delta.clamp_(min=-4, max=4) - debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: scale = (assumed_scale / ((delta ** 2).mean().sqrt() + eps)) From 52593f85eb4a4597b6019762912cfedda9345aac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 22 May 2026 14:51:42 +0800 Subject: [PATCH 1160/1191] Copy muon-core code from rubik_baseline_tb_dan_largeinit_simpler42 and fix bug with dtype mismatches. --- .../ASR/zapformer/batched_rubik.py | 160 +++++++++--------- 1 file changed, 83 insertions(+), 77 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index cb7fc81c99..62338c7fff 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -27,6 +27,13 @@ from torch import Tensor from torch.optim import Optimizer +#try: +# from nanochat.common import print0 +# from nanochat.common import COMPUTE_DTYPE +#except: +#from logging import info as print0 +#COMPUTE_DTYPE = torch.float32 +COMPUTE_DTYPE = torch.bfloat16 class BatchedOptimizer(Optimizer): @@ -252,7 +259,64 @@ def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): -def cubic_decay_step(group, state, grad): +# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2) +# From https://arxiv.org/pdf/2505.16932 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +] + +#@torch.compile(dynamic=False, fullgraph=True) +def muon_step_fused( + stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients + momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer + second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment + momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient + lr_t: Tensor, # () - 0-D CPU tensor, learning rate + beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment + eps: Tensor, + ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations + red_dim: int, # -1 or -2 - reduction dimension for variance +) -> Tensor: + """ + Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update + All in one compiled graph to eliminate Python overhead between ops. + Some of the constants are 0-D CPU tensors to avoid recompilation when values change. + """ + + # Nesterov momentum + momentum = momentum_t.to(stacked_grads.dtype) + momentum_buffer.lerp_(stacked_grads, 1 - momentum) + g = stacked_grads.lerp_(momentum_buffer, momentum) + + # Polar express + # Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range) + X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g + X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6) + if g.size(-2) > g.size(-1): # Tall matrix + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X.mT @ X + B = b * A + c * (A @ A) + X = a * X + X @ B + else: # Wide matrix (original math) + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + g = X + + # Variance normalization + v_mean = g.float().square().mean(dim=red_dim, keepdim=True) + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2_t) + g = g / (second_momentum_buffer.sqrt() + eps).to(g.dtype) + lr = lr_t.to(g.dtype) + return -lr * g + + +def muon_core_step(group, state, grad): lr = group["lr"] eps = group["eps"] step = state["step"] @@ -266,85 +330,27 @@ def cubic_decay_step(group, state, grad): rows, cols = matrix_shape(orig_shape[1:]) grad = grad.reshape(batch_size, rows, cols) - if "moving_grad" not in state: + if "momentum_buffer" not in state: assert step < 2 - state["moving_grad"] = torch.zeros(batch_size, rows, cols, device=grad.device) - state["row_stats"] = torch.ones(batch_size, rows, 1, device=grad.device) - state["col_stats"] = torch.ones(batch_size, 1, cols, device=grad.device) - - - moving_grad = state["moving_grad"] - row_stats = state["row_stats"] - col_stats = state["col_stats"] - - # we half update the stats here, half update them later. - norm_grad, norm_grad_precon, invP = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) - - # add the grad to the moving-average grad; the scaling factor used here - # doesn't matter as it all gets normalized later. - moving_grad.add_(norm_grad_precon, alpha=(1-beta1)) - - # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were - # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. - prod3 = scaled_three_way_product(moving_grad) - - # dividing the following by invP means we are using 1 / invP as a scale for computing - # norms, as if we were to compute the norm of delta ~= moving_grad / invP after doing - # moving_grad.add_(prod3 * cubic_alpha). - cubic_alpha = compute_alpha(moving_grad / invP, prod3 / invP, beta1) - # cubic_alpha shape: (batch_size, 1, 1) - - moving_grad.add_(prod3 * cubic_alpha) - - # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" - # will have a smaller variance than the grad itself because of being a mean over independent elements. - # we rescale before getting the stats, to have the same variance as if it were the grad. - # The actual variance of moving_grad also depends on the variance of the original grads; this is just - # a scalar component in the variance to accountn for averaging-over-time effects. - assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) - - # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will - # make the stats update more dominated by grad rather than moving_grad. - beta2b_scale = 0.1 - beta2b = beta2b_scale * beta2 + (1 - beta2b_scale) - delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, - beta2b, eps) - - nesterov = True - if nesterov: - delta.lerp_(norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. - - - - #if True: - # - #if step < 5 or (step < 500 and step % 10 == 0): - #logging.info(f"shape={delta.shape}, grad rms is {(grad ** 2).mean(dim=(1,2)).sqrt()}, norm_grad rms is {(norm_grad ** 2).mean(dim=(1,2)).sqrt()}, norm_grad_precon rms is {(norm_grad_precon ** 2).mean(dim=(1,2)).sqrt()}, delta rms is {(delta ** 2).mean(dim=(1,2)).sqrt()}, moving_grad rms is {(moving_grad ** 2).mean(dim=(1,2)).sqrt()}, row_stats_sqrt rms is {row_stats.sqrt().mean(dim=(1,2))}, col_stats sqrt rms is {col_stats.sqrt().mean(dim=(1,2))}") - - - # do "immediate" normalization of 2-norm of the step to make the overall scale of the update what - # it would be if this was a normal decaying-beta1 update and the stats were i.i.d.. - # below is the assumed scale of d if stats were i.i.d. and this were a more normal adam-style - # accumulator with beta equal to beta1. - # This should make divergence less likely. - # we ignore nesterov modification for purposes of this formula, it should make little difference anyway - # if beta1 is close to 1. - + state["momentum_buffer"] = torch.zeros(batch_size, rows, cols, device=grad.device, dtype=COMPUTE_DTYPE) + if rows > cols: + state["second_momentum_buffer"] = torch.zeros(batch_size, rows, 1, device=grad.device, dtype=torch.float) + else: + state["second_momentum_buffer"] = torch.zeros(batch_size, 1, cols, device=grad.device, dtype=torch.float) - # doing the extra sqrt on the scale means we, in effect, half-normalize the magnitude. - # we can, I think come up with an argument that it's similar to using a different value of beta. - # (argument would require independence of grads on different steps.) + momentum_buffer = state["momentum_buffer"] + second_momentum_buffer = state["second_momentum_buffer"] - debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) - if debug: - cubic_alpha_ratio = -cubic_alpha / (1-beta1) - scale = (assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) - logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied], alpha_ratio={cubic_alpha_ratio.flatten()}, delta-max={delta.abs().max(dim=1)[0].max(dim=1)[0]}") + def t(x): + return torch.tensor(x, device=grad.device, dtype=COMPUTE_DTYPE) + def tf(x): + return torch.tensor(x, device=grad.device, dtype=torch.float) - delta.mul_(-lr) + step = muon_step_fused(grad.to(COMPUTE_DTYPE), momentum_buffer, second_momentum_buffer, + t(beta1), t(lr), tf(beta2), t(eps), 5, (-1 if rows > cols else -2)) - return delta.reshape(orig_shape) + return step.reshape(orig_shape) def scaling_step(group, param, state, grad): @@ -361,7 +367,7 @@ def scaling_step(group, param, state, grad): scalar_scale = group["scalar_scale"] if grad.ndim >= 2 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): - delta = cubic_decay_step(group, state, grad) + delta = muon_core_step(group, state, grad) else: # biases and similar-shaped tensors delta = adam_step(group, state, grad) @@ -417,8 +423,8 @@ def adam_step(group, state, grad): exp_avg_sq = state["exp_avg_sq"] except KeyError as e: assert step < 2 - exp_avg = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) - exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + exp_avg = torch.zeros(*grad.shape, device=grad.device, dtype=COMPUTE_DTYPE) + exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=COMPUTE_DTYPE) state["exp_avg"] = exp_avg state["exp_avg_sq"] = exp_avg_sq From 903576b26c52f24026ab4688a70ccbc78624be86 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 22 May 2026 17:19:44 +0800 Subject: [PATCH 1161/1191] Take muon-core rubik from rubik_baseline_tb_dan_largeinit_simpler45. --- .../ASR/zapformer/batched_rubik.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 62338c7fff..cf4cdc8851 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -27,12 +27,12 @@ from torch import Tensor from torch.optim import Optimizer -#try: -# from nanochat.common import print0 -# from nanochat.common import COMPUTE_DTYPE -#except: -#from logging import info as print0 -#COMPUTE_DTYPE = torch.float32 +# try: +# from nanochat.common import print0 +# from nanochat.common import COMPUTE_DTYPE +# except: +# from logging import info as print0 +# #COMPUTE_DTYPE = torch.float32 COMPUTE_DTYPE = torch.bfloat16 @@ -309,11 +309,15 @@ def muon_step_fused( g = X # Variance normalization + beta2 = beta2_t.to(second_momentum_buffer.dtype) v_mean = g.float().square().mean(dim=red_dim, keepdim=True) - second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2_t) + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) g = g / (second_momentum_buffer.sqrt() + eps).to(g.dtype) lr = lr_t.to(g.dtype) - return -lr * g + beta1 = momentum_t.to(g.dtype) + # assumed scale of step size if it arose from momentum decay from i.i.d. variance-1 grads. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + return -lr * assumed_scale * g def muon_core_step(group, state, grad): @@ -344,11 +348,9 @@ def muon_core_step(group, state, grad): def t(x): return torch.tensor(x, device=grad.device, dtype=COMPUTE_DTYPE) - def tf(x): - return torch.tensor(x, device=grad.device, dtype=torch.float) step = muon_step_fused(grad.to(COMPUTE_DTYPE), momentum_buffer, second_momentum_buffer, - t(beta1), t(lr), tf(beta2), t(eps), 5, (-1 if rows > cols else -2)) + t(beta1), t(lr), t(beta2), t(eps), 5, (-1 if rows > cols else -2)) return step.reshape(orig_shape) @@ -556,8 +558,8 @@ def _test_batched_rubik(hidden_dim: int): B = 4 T = 2 logging.info("in test_batched_rubik") - # device = torch.device('cuda') - device = torch.device("cpu") + device = torch.device('cuda') + #device = torch.device("cpu") dtype = torch.float32 torch.random.manual_seed(42) From a0e8c17988af0d160cfc38a6f8139c669c3f86aa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 23 May 2026 13:49:35 +0800 Subject: [PATCH 1162/1191] Remove use of invP in computing alpha. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 14 ++++---------- egs/librispeech/ASR/zapformer/rubik.py | 17 ++++------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 78fa277ad4..50894fd60d 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -221,12 +221,9 @@ def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x = x / row_denom col_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) col_denom = (col_stats + eps) - row_denom_sqrt = row_denom.sqrt() - col_denom_sqrt = col_denom.sqrt() - x_half_norm = (x * row_denom_sqrt) / col_denom_sqrt + x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() x = x / col_denom - invP = row_denom * col_denom - return x, x_half_norm, invP + return x, x_half_norm @@ -278,7 +275,7 @@ def cubic_decay_step(group, state, grad): col_stats = state["col_stats"] # we half update the stats here, half update them later. - norm_grad, norm_grad_precon, invP = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -288,10 +285,7 @@ def cubic_decay_step(group, state, grad): # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) - # dividing the following by invP means we are using 1 / invP as a scale for computing - # norms, as if we were to compute the norm of delta ~= moving_grad / invP after doing - # moving_grad.add_(prod3 * cubic_alpha). - cubic_alpha = compute_alpha(moving_grad / invP, prod3 / invP, beta1) + cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: (batch_size, 1, 1) moving_grad.add_(prod3 * cubic_alpha) diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index 85ee13aabe..f2ce5d313d 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -88,8 +88,6 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: - - def matrix_shape(shape): """ shape is expected to be a torch.Size with at least two dimensions. @@ -127,12 +125,9 @@ def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): x = x / row_denom col_stats.mul_(beta2).add_(x.abs().mean(dim=0, keepdim=True), alpha=(1 - beta2)) col_denom = (col_stats + eps) - row_denom_sqrt = row_denom.sqrt() - col_denom_sqrt = col_denom.sqrt() - x_half_norm = (x * row_denom_sqrt) / col_denom_sqrt + x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() x = x / col_denom - invP = row_denom * col_denom - return x, x_half_norm, invP + return x, x_half_norm def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): @@ -181,7 +176,7 @@ def cubic_decay_step(group, state, grad): col_stats = state["col_stats"] # we half update the stats here, half update them later. - norm_grad, norm_grad_precon, invP = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) # add the grad to the moving-average grad; the scaling factor used here # doesn't matter as it all gets normalized later. @@ -191,11 +186,7 @@ def cubic_decay_step(group, state, grad): # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. prod3 = scaled_three_way_product(moving_grad) - - # dividing the following by invP means we are using 1 / invP as a scale for computing - # norms, as if we were to compute the norm of delta ~= moving_grad / invP after doing - # moving_grad.add_(prod3 * cubic_alpha). - cubic_alpha = compute_alpha(moving_grad / invP, prod3 / invP, beta1) + cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: scalar moving_grad.add_(prod3 * cubic_alpha) From 655be80908d5e3bc08f06ed9da1fd0608aef30c0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 23 May 2026 14:50:50 +0800 Subject: [PATCH 1163/1191] Introduce alpha_power = 0.5, making alpha closer to 1. --- .../ASR/zapformer/batched_rubik.py | 46 +++++++++++------ egs/librispeech/ASR/zapformer/rubik.py | 51 ++++++++++++------- 2 files changed, 62 insertions(+), 35 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 50894fd60d..3db7041e2f 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -150,18 +150,25 @@ def scaled_three_way_product(x): def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: """ - Solve the equation: ||x + alpha y||_2^2 == ||beta x||_2^2 - - x.x + 2 alpha y.x + alpha^2 y.y = beta^2 x.x - alpha^2 y.y + 2 alpha x.y + (1-beta^2) x.x = 0 - (a,b,c) = (y.y, 2 alpha x.y, x.x) - alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. - # treat the thing inside the sqrt as zero if - # negative, this + Computes the amount of cubic decay to do for each parameter tensor in the batch, as + in effect batch of scalars (one per parameter) of shape (batch_size, 1, 1). + + Solve the equation: ||x - alpha y||_2^2 == ||beta x||_2^2 + + x.x - 2 alpha y.x + alpha^2 y.y = beta^2 x.x + alpha^2 y.y - 2 alpha x.y + (1-beta^2) x.x = 0 + (a,b,c) = (y.y, -2 alpha x.y, x.x) + alpha = (-b - sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. + # factoring out 2 from the top and bottom we get: - so alpha = (-x.y + sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y + so alpha = (x.y - sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y ... we treat the thing inside the sqrt as zero if it is negative, - which gives us the closest real solution + which gives us the closest real solution to zero. + + We then apply a formula that you can see at the bottom, which chooses the + smallest (closest to zero) of two formulae, see the comments. This is basically + heuristic; the safety_factor * min_sum_scale is a safety thing to reduce the + chance of eigenvalues flipping sign. """ eps = 1.0e-40 xx = x.square().mean(dim=(1, 2), keepdim=True) @@ -171,16 +178,23 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: # this alpha is the value that solves exactly for the requested difference in norm. # this will be negative. - alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / yyeps + alpha = (xy - (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / yyeps - # min_sum_scale is the value of alpha that would minimize the norm of a + alpha y. - min_sum_scale = -xy / yyeps + # min_sum_scale is the value of alpha that would minimize the norm of a - alpha y. + min_sum_scale = xy / yyeps # safety_factor = 0.5 means we are only willing to go halfway to that value that minimizes the norm, # to avoid change of eigenvalue sign / overshoot, which can ultimately lead to certain # parameter eigenvalues getting too large. safety_factor = 0.5 - return torch.maximum(safety_factor * min_sum_scale, alpha) # return the closet to zero of these two formulae. + # alpha_power is a heuristic value that interpolates between the computed alpha, and alpha=(1-beta). + # the intention is that if the singular values are quite peaky (hence alpha << 1), + # we want to make sure that we're doing an adequate amount of decay for the smaller singular values. + alpha_power = 0.5 + + # return the closest to zero of the two formulae below. + return torch.minimum(safety_factor * min_sum_scale, + ((1-beta) ** (1-alpha_power)) * (alpha.clamp(min=1.0e-10) ** alpha_power)) def matrix_shape(shape): @@ -288,7 +302,7 @@ def cubic_decay_step(group, state, grad): cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: (batch_size, 1, 1) - moving_grad.add_(prod3 * cubic_alpha) + moving_grad.add_(prod3 * cubic_alpha, alpha=-1) # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" # will have a smaller variance than the grad itself because of being a mean over independent elements. @@ -334,7 +348,7 @@ def cubic_decay_step(group, state, grad): debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: - cubic_alpha_ratio = -cubic_alpha / (1-beta1) + cubic_alpha_ratio = cubic_alpha / (1-beta1) scale = (assumed_scale / ((delta ** 2).mean(dim=(1, 2), keepdim=True).sqrt() + eps)) logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied], alpha_ratio={cubic_alpha_ratio.flatten()}, delta-max={delta.abs().max(dim=1)[0].max(dim=1)[0]}") diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index f2ce5d313d..f2b70be313 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -53,18 +53,26 @@ def scaled_three_way_product(x): def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: """ - Solve the equation: ||x + alpha y||_2^2 == ||beta x||_2^2 - - x.x + 2 alpha y.x + alpha^2 y.y = beta^2 x.x - alpha^2 y.y + 2 alpha x.y + (1-beta^2) x.x = 0 - (a,b,c) = (y.y, 2 alpha x.y, x.x) - alpha = (-b + sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. - # treat the thing inside the sqrt as zero if - # negative, this + Computes the amount of cubic decay to do for each parameter tensor in the batch, as + a scalar. + + First compute alpha that solves the equation: ||x - alpha y||_2^2 == ||beta x||_2^2 + + + x.x - 2 alpha y.x + alpha^2 y.y = beta^2 x.x + alpha^2 y.y - 2 alpha x.y + (1-beta^2) x.x = 0 + (a,b,c) = (y.y, -2 alpha x.y, x.x) + alpha = (-b - sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. + # factoring out 2 from the top and bottom we get: - so alpha = (-x.y + sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y + so alpha = (x.y - sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y ... we treat the thing inside the sqrt as zero if it is negative, which gives us the closest real solution + + We then apply a formula that you can see at the bottom, which chooses the + smallest (closest to zero) of two formulae, see the comments. This is basically + heuristic; the safety_factor * min_sum_scale is a safety thing to reduce the + chance of eigenvalues flipping sign. """ eps = 1.0e-40 xx = x.square().mean() @@ -74,17 +82,23 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: # this alpha is the value that solves exactly for the requested difference in norm. # this will be negative. - alpha = (-xy + (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / yyeps + alpha = (xy - (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / yyeps - # min_sum_scale is the value of alpha that would minimize the norm of a + alpha y. - min_sum_scale = -xy / yyeps + # min_sum_scale is the value of alpha that would minimize the norm of a - alpha y. + min_sum_scale = xy / yyeps # safety_factor = 0.5 means we are only willing to go halfway to that value that minimizes the norm, # to avoid change of eigenvalue sign / overshoot, which can ultimately lead to certain # parameter eigenvalues getting too large. safety_factor = 0.5 - return torch.maximum(safety_factor * min_sum_scale, alpha) # return the closet to zero of these two formulae. + # alpha_power is a heuristic value that interpolates between the computed alpha, and alpha=(1-beta). + # the intention is that if the singular values are quite peaky (hence alpha << 1), + # we want to make sure that we're doing an adequate amount of decay for the smaller singular values. + alpha_power = 0.5 + # return the closest to zero of the two formulae below. + return torch.minimum(safety_factor * min_sum_scale, + ((1-beta) ** (1-alpha_power)) * (alpha.clamp(min=1.0e-10) ** alpha_power)) @@ -189,7 +203,7 @@ def cubic_decay_step(group, state, grad): cubic_alpha = compute_alpha(moving_grad, prod3, beta1) # cubic_alpha shape: scalar - moving_grad.add_(prod3 * cubic_alpha) + moving_grad.add_(prod3 * cubic_alpha, alpha=-1) # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" # will have a smaller variance than the grad itself because of being a mean over independent elements. @@ -212,14 +226,13 @@ def cubic_decay_step(group, state, grad): debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) if debug: + cubic_alpha_ratio = cubic_alpha / (1-beta1) scale = (assumed_scale / ((delta ** 2).mean().sqrt() + eps)) - logging.info(f"shape={prod3.shape}, scale={scale.flatten()} [not applied]") - #delta = delta * scale - - ans = -lr * delta + logging.info(f"shape={prod3.shape}, scale={scale} [not applied], alpha_ratio={cubic_alpha_ratio}, delta-max={delta.abs().max()}") - return ans.reshape(orig_shape) + delta.mul_(-lr) + return delta.reshape(orig_shape) From 44f6639561d1b12fc5883863c78441ff93e06ea3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 23 May 2026 15:09:01 +0800 Subject: [PATCH 1164/1191] Increase alpha_power from 0.5 to 0.75. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- egs/librispeech/ASR/zapformer/rubik.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 3db7041e2f..55114f7043 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -190,7 +190,7 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: # alpha_power is a heuristic value that interpolates between the computed alpha, and alpha=(1-beta). # the intention is that if the singular values are quite peaky (hence alpha << 1), # we want to make sure that we're doing an adequate amount of decay for the smaller singular values. - alpha_power = 0.5 + alpha_power = 0.75 # return the closest to zero of the two formulae below. return torch.minimum(safety_factor * min_sum_scale, diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py index f2b70be313..e2c4a75f8d 100644 --- a/egs/librispeech/ASR/zapformer/rubik.py +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -94,7 +94,7 @@ def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: # alpha_power is a heuristic value that interpolates between the computed alpha, and alpha=(1-beta). # the intention is that if the singular values are quite peaky (hence alpha << 1), # we want to make sure that we're doing an adequate amount of decay for the smaller singular values. - alpha_power = 0.5 + alpha_power = 0.75 # return the closest to zero of the two formulae below. return torch.minimum(safety_factor * min_sum_scale, From 0692a8735d15ba7816008ba4b01b4d5f6b5aed2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 May 2026 13:29:47 +0800 Subject: [PATCH 1165/1191] Fixes suggested by AI from https://github.com/k2-fsa/icefall/pull/2082 --- .../ASR/zapformer/alternating_spec_augment.py | 16 +- .../ASR/zapformer/asr_datamodule.py | 17 +- .../ASR/zapformer/attention_decoder.py | 584 +++++++++++++++++- egs/librispeech/ASR/zapformer/ctc_decode.py | 6 +- egs/librispeech/ASR/zapformer/my_profile.py | 1 - .../ASR/zapformer/streaming_decode.py | 10 +- .../ASR/zapformer/test_subsampling.py | 150 ----- egs/librispeech/ASR/zapformer/train.py | 3 +- .../ASR/zapformer/zapformer_modules.py | 2 +- .../ASR/zipformer/attention_decoder.py | 2 +- icefall/diagnostics.py | 3 +- 11 files changed, 609 insertions(+), 185 deletions(-) mode change 120000 => 100644 egs/librispeech/ASR/zapformer/attention_decoder.py delete mode 100755 egs/librispeech/ASR/zapformer/test_subsampling.py diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 0214f80065..264d72f0b4 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -26,7 +26,7 @@ def __init__( p=0.9, # probability of doing core SpecAug augmentation time_warp_p=0.9, # probability of doing time warping. time_warp_factor=80, # as in original SpecAug paper. - seed=None, # if you leave this as none it will use random.random() + seed=None, # if you leave this as none it will use torch.randint(0, 100000, ()).item() ): super().__init__() assert 0 <= p <= 1 @@ -220,7 +220,6 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ # so definitely this clamping will happen for less than half of the pairs of sequences. padding_tot_rlen = (1. - mask_tot_rlen).clamp(min=0.2) # (batch_size, 1) - eps = 1.0e-20 # get padding lengths by randomly placing dividers on the line of length "padding_tot_rlen" # P is the number of padding regions for each pair of sequences. @@ -263,18 +262,18 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ return mask_starts, mask_ends def state_dict(self, **kwargs) -> Dict[str, Any]: - - dict = { } + state = { } for name in ["max_feature_mask_fraction", "num_feature_masks", "max_frame_mask_fraction", "max_frame_mask_size", "p"]: - dict[name] = getattr(self, name) + state[name] = getattr(self, name) + return state def load_state_dict(self, state_dict: Dict[str, Any]): for name in ["max_feature_mask_fraction", "num_feature_masks", "max_frame_mask_fraction", "max_frame_mask_size", "p"]: if name in state_dict: - setattr(self, name, state_dict["name"]) + setattr(self, name, state_dict[name]) def time_warp_impl(features: torch.Tensor, factor: int, @@ -343,7 +342,7 @@ def time_warp( # Randomly choose whether this transform is applied continue features[sequence_idx] = time_warp_impl( - features[sequence_idx], factor=time_warp_factor + features[sequence_idx], factor=time_warp_factor, generator=generator, ) else: for sequence_idx, num_frames in enumerate(feature_lens): @@ -394,8 +393,7 @@ def _test_alternating_spec_augment(): aspec_augment = lambda x: spec_augment(x, supervision_segments) features = torch.randn(B, T, F, device=device) - lengths = torch.tensor([ features.shape[1] ] * B, dtype=torch.long).to(device=device) - #print("features=", features) + features = aspec_augment(features) frame_is_masked = features[:, :, 0] == features[:, :, -1] diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py index b25653a66a..c4a628df01 100755 --- a/egs/librispeech/ASR/zapformer/asr_datamodule.py +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py @@ -18,10 +18,12 @@ import argparse import inspect +import glob import logging from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import re import random # to set its random seed import numpy as np # to set its random seed @@ -34,22 +36,15 @@ PrecomputedFeatures, SimpleCutSampler, ) +import lhotse + # MulticopyDataset is a modified version of K2SpeechRecognitionDataset from # lhotse.dataset, modified to, in training mode, to return a batch that has multiple # different copies of the same data having different Musan # augmentations and the first having none; and also include the key "num_copies" # in the batch which would be 1 for the validation data (no Musan) and 2 for the # different copies of the training data with musan. -try: - from multicopy_dataset import MulticopyDataset # interface like K2SpeechRecognitionDataset -except: - pass - -try: - from icefall.utils import dist_barrier -except: - pass - +from multicopy_dataset import MulticopyDataset # interface like K2SpeechRecognitionDataset from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -304,7 +299,7 @@ def train_dataloaders( num_copies=num_copies, cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, + input_transforms=[], return_cuts=self.args.return_cuts, ) diff --git a/egs/librispeech/ASR/zapformer/attention_decoder.py b/egs/librispeech/ASR/zapformer/attention_decoder.py deleted file mode 120000 index 830180a0cd..0000000000 --- a/egs/librispeech/ASR/zapformer/attention_decoder.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zapformer/attention_decoder.py b/egs/librispeech/ASR/zapformer/attention_decoder.py new file mode 100644 index 0000000000..648be4b1e0 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/attention_decoder.py @@ -0,0 +1,583 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import List, Optional + +import k2 +import torch +import torch.nn as nn +from label_smoothing import LabelSmoothingLoss +from zapformer_utils import penalize_abs_values_gt + +from icefall.utils import add_eos, add_sos, make_pad_mask + + +class AttentionDecoderModel(nn.Module): + """ + Args: + vocab_size (int): Number of classes. + decoder_dim: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + num_heads (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + sos_id: int = 1, + eos_id: int = 1, + dropout: float = 0.1, + ignore_id: int = -1, + label_smoothing: float = 0.1, + ): + super().__init__() + self.eos_id = eos_id + self.sos_id = sos_id + self.ignore_id = ignore_id + + # For the segment of the warmup period, we let the Embedding + # layer learn something. Then we start to warm up the other encoders. + self.decoder = TransformerDecoder( + vocab_size=vocab_size, + d_model=decoder_dim, + num_decoder_layers=num_decoder_layers, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + + # Used to calculate attention-decoder loss + self.loss_fun = LabelSmoothingLoss( + ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum" + ) + + def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor): + """Prepare ys_in_pad and ys_out_pad.""" + ys_in = add_sos(ys, sos_id=self.sos_id) + # [B, S+1], start with SOS + ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id) + ys_in_lens = ys_lens + 1 + + ys_out = add_eos(ys, eos_id=self.eos_id) + # [B, S+1], end with EOS + ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id) + + return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64) + + def calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys: k2.RaggedTensor, + ys_lens: torch.Tensor, + ) -> torch.Tensor: + """Calculate attention-decoder loss. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: The attention-decoder loss. + """ + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + loss = self.loss_fun(x=decoder_out, target=ys_out_pad) + return loss + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + token_ids: List[List[int]], + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from attention-decoder. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: A tensor of shape (batch, num_tokens). + """ + ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device) + row_splits = ys.shape.row_splits(1) + ys_lens = row_splits[1:] - row_splits[:-1] + + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + batch_size, _, num_classes = decoder_out.size() + nll = nn.functional.cross_entropy( + decoder_out.view(-1, num_classes), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction="none", + ) + nll = nll.view(batch_size, -1) + return nll + + +class TransformerDecoder(nn.Module): + """Transfomer decoder module. + + Args: + vocab_size: output dim + d_model: decoder dimension + num_decoder_layers: number of decoder layers + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + vocab_size: int, + d_model: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + super().__init__() + self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model) + + # Absolute positional encoding + self.pos = PositionalEncoding(d_model, dropout_rate=0.1) + + self.num_layers = num_decoder_layers + self.layers = nn.ModuleList( + [ + DecoderLayer( + d_model=d_model, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + for _ in range(num_decoder_layers) + ] + ) + + self.output_layer = nn.Linear(d_model, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + memory: Optional[torch.Tensor] = None, + memory_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch, tgt_len). + x_lens: A tensor of shape (batch,) containing the number of tokens in `x` + before padding. + memory: + Memory sequence of shape (batch, src_len, memory_dim). + memory_lens: + A tensor of shape (batch,) containing the number of frames in + `memory` before padding. + + Returns: + Decoded token logits before softmax (batch, tgt_len, vocab_size) + """ + x = self.embed(x) # (batch, tgt_len, embed_dim) + x = self.pos(x) # (batch, tgt_len, embed_dim) + + x = x.permute(1, 0, 2) # (tgt_len, batch, embed_dim) + + # construct attn_mask for self-attn modules + padding_mask = make_pad_mask(x_lens) # (batch, tgt_len) + causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len) + attn_mask = torch.logical_or( + padding_mask.unsqueeze(1), # (batch, 1, seq_len) + torch.logical_not(causal_mask).unsqueeze(0), # (1, seq_len, seq_len) + ) # (batch, seq_len, seq_len) + + if memory is not None: + memory = memory.permute(1, 0, 2) # (src_len, batch, memory_dim) + # construct memory_attn_mask for cross-attn modules + memory_padding_mask = make_pad_mask(memory_lens) # (batch, src_len) + memory_attn_mask = memory_padding_mask.unsqueeze(1) # (batch, 1, src_len) + else: + memory_attn_mask = None + + for i, mod in enumerate(self.layers): + x = mod( + x, + attn_mask=attn_mask, + memory=memory, + memory_attn_mask=memory_attn_mask, + ) + + x = x.permute(1, 0, 2) # (batch, tgt_len, vocab_size) + x = self.output_layer(x) + + return x + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + d_model: equal to decoder_dim, total dimension of the decoder + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + d_model: int = 512, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + """Construct an DecoderLayer object.""" + super(DecoderLayer, self).__init__() + + self.norm_self_attn = nn.LayerNorm(d_model) + self.self_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, dropout=0.0 + ) + + self.norm_src_attn = nn.LayerNorm(d_model) + self.src_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0 + ) + + self.norm_ff = nn.LayerNorm(d_model) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, feedforward_dim), + Swish(), + nn.Dropout(dropout), + nn.Linear(feedforward_dim, d_model), + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input sequence of shape (seq_len, batch, embed_dim). + attn_mask: A binary mask for self-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + memory: Memory sequence of shape (seq_len, batch, memory_dim). + memory_attn_mask: A binary mask for cross-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + """ + # self-attn module + qkv = self.norm_self_attn(x) + self_attn_out = self.self_attn( + query=qkv, key=qkv, value=qkv, attn_mask=attn_mask + ) + x = x + self.dropout(self_attn_out) + + # cross-attn module + q = self.norm_src_attn(x) + src_attn_out = self.src_attn( + query=q, key=memory, value=memory, attn_mask=memory_attn_mask + ) + x = x + self.dropout(src_attn_out) + + # feed-forward module + x = x + self.dropout(self.feed_forward(self.norm_ff(x))) + + return x + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, but must be a multiple of num_heads. + num_heads: number of parallel attention heads. + memory_dim: dimension of memory embedding, optional. + dropout: a Dropout layer on attn_output_weights. + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + memory_dim: Optional[int] = None, + dropout: float = 0.0, + ): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.head_dim = attention_dim // num_heads + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True) + self.linear_k = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + self.linear_v = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + + self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute dot product attention. + + Args: + query: Query tensor of shape (tgt_len, batch, embed_dim). + key: Key tensor of shape (src_len, batch, embed_dim or memory_dim). + value: Value tensor of shape (src_len, batch, embed_dim or memory_dim). + key_padding_mask: A binary mask indicating which elements are padding. + Its shape is (batch, src_len). + attn_mask: A binary mask indicating which elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + + Returns: + Output tensor of shape (tgt_len, batch, embed_dim). + """ + num_heads = self.num_heads + head_dim = self.head_dim + + tgt_len, batch, _ = query.shape + src_len = key.shape[0] + + q = self.linear_q(query) # (tgt_len, batch, num_heads * head_dim) + k = self.linear_k(key) # (src_len, batch, num_heads * head_dim) + v = self.linear_v(value) # (src_len, batch, num_heads * head_dim) + + q = q.reshape(tgt_len, batch, num_heads, head_dim) + q = q.permute(1, 2, 0, 3) # (batch, head, tgt_len, head_dim) + k = k.reshape(src_len, batch, num_heads, head_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, head_dim, src_len) + v = v.reshape(src_len, batch, num_heads, head_dim) + v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1) + + # Note: could remove the scaling operation when using ScaledAdam + # (batch, head, tgt_len, src_len) + attn_weights = torch.matmul(q, k) / math.sqrt(head_dim) + + # From zipformer.py: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + + if attn_mask is not None: + assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == ( + batch, + tgt_len, + src_len, + ), attn_mask.shape + attn_weights = attn_weights.masked_fill( + attn_mask.unsqueeze(1), float("-inf") + ) + + attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + # (batch * head, tgt_len, head_dim) + attn_output = torch.bmm(attn_weights, v) + assert attn_output.shape == ( + batch * num_heads, + tgt_len, + head_dim, + ), attn_output.shape + + attn_output = attn_output.transpose(0, 1).contiguous() + attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim) + + # (batch, tgt_len, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class PositionalEncoding(nn.Module): + """Positional encoding. + Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def subsequent_mask(size, device="cpu", dtype=torch.bool): + """Create mask for subsequent steps (size, size). + + :param int size: size of mask + :param str device: "cpu" or "cuda" or torch.Tensor.device + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=dtype) + return torch.tril(ret, out=ret) + + +def _test_attention_decoder_model(): + m = AttentionDecoderModel( + vocab_size=500, + decoder_dim=512, + num_decoder_layers=6, + attention_dim=512, + num_heads=8, + feedforward_dim=2048, + memory_dim=384, + dropout=0.1, + sos_id=1, + eos_id=1, + ignore_id=-1, + ) + + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + m.eval() + encoder_out = torch.randn(2, 50, 384) + encoder_out_lens = torch.full((2,), 50) + token_ids = [[1, 2, 3, 4], [2, 3, 10]] + + nll = m.nll(encoder_out, encoder_out_lens, token_ids) + print(nll) + + +if __name__ == "__main__": + _test_attention_decoder_model() diff --git a/egs/librispeech/ASR/zapformer/ctc_decode.py b/egs/librispeech/ASR/zapformer/ctc_decode.py index 963e4f2047..dd1ec0c7e0 100755 --- a/egs/librispeech/ASR/zapformer/ctc_decode.py +++ b/egs/librispeech/ASR/zapformer/ctc_decode.py @@ -348,7 +348,7 @@ def get_parser(): path of the n paths is the decoding result. - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with the given beam, rescore them with the attention decoder. - - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fussion during + - (12) ctc-prefix-beam-search-shallow-fusion. Use NNLM shallow fusion during beam search, LODR and hotwords are also supported in this decoding method. """, ) @@ -387,7 +387,7 @@ def get_parser(): "--nnlm-scale", type=float, default=0, - help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion. + help="""The scale of the neural network LM, 0 means don't use nnlm shallow fusion. Used only when `--use-shallow-fusion` is set to True. """, ) @@ -600,7 +600,7 @@ def decode_one_batch( ans[a_scale_str] = hyps return ans - if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + if params.decoding_method in [ "ctc-prefix-beam-search-shallow-fussion", "ctc-prefix-beam-search-shallow-fusion" ]: token_ids = ctc_prefix_beam_search_shallow_fussion( ctc_output=ctc_output, encoder_out_lens=encoder_out_lens, diff --git a/egs/librispeech/ASR/zapformer/my_profile.py b/egs/librispeech/ASR/zapformer/my_profile.py index 333b139689..458a759694 100755 --- a/egs/librispeech/ASR/zapformer/my_profile.py +++ b/egs/librispeech/ASR/zapformer/my_profile.py @@ -34,7 +34,6 @@ get_joiner_model, get_params, ) -from zapformer import BypassModule from icefall.profiler import get_model_profile from icefall.utils import make_pad_mask diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py index c1437297f3..b7d2ed7af0 100755 --- a/egs/librispeech/ASR/zapformer/streaming_decode.py +++ b/egs/librispeech/ASR/zapformer/streaming_decode.py @@ -215,7 +215,7 @@ def get_parser(): type=str2bool, default=False, help="""If True, decode commonvoice in addition to librispeech test sets.""", - ) + ) add_model_arguments(parser) @@ -435,7 +435,7 @@ def streaming_forward( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] - + ( encoder_out, encoder_out_lens, @@ -494,7 +494,7 @@ def decode_one_chunk( # Make sure the length after encoder_embed is at least 1. # The encoder_embed subsample features (T - 7) // 2 - tail_length = chunk_size * 2 + 7 + tail_length = chunk_size * 2 + 7 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -830,7 +830,7 @@ def main(): start = params.epoch - params.avg + 1 filenames = [] for i in range(start, params.epoch + 1): - if start >= 0: + if i >= 1: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) @@ -913,7 +913,7 @@ def main(): giga_dev_cuts = gigaspeech.dev_cuts() test_sets += ["giga-dev", "giga-test"] test_cuts += [giga_dev_cuts, giga_test_cuts] - + if args.cv: commonvoice = CommonVoice(args.manifest_dir) cv_test_cuts = commonvoice.test_cuts() diff --git a/egs/librispeech/ASR/zapformer/test_subsampling.py b/egs/librispeech/ASR/zapformer/test_subsampling.py deleted file mode 100755 index b502d5a773..0000000000 --- a/egs/librispeech/ASR/zapformer/test_subsampling.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 - -import torch -from subsampling import Conv2dSubsampling - -# TODO: fix, this does not work right tnow -def test_conv2d_subsampling(): - layer1_channels = 8 - layer2_channels = 32 - layer3_channels = 128 - - out_channels = 192 - encoder_embed = Conv2dSubsampling( - in_channels=80, - out_channels=out_channels, - layer1_channels=layer1_channels, - layer2_channels=layer2_channels, - layer3_channels=layer3_channels, - ) - N = 2 - T = 200 - num_features = 80 - x = torch.rand(N, T, num_features) - x_copy = x.clone() - - x = x.unsqueeze(1) # (N, 1, T, num_features) - - x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1) - assert x.shape == (N, layer1_channels, T - 2, num_features) - # (2, 8, 198, 80) - - x = encoder_embed.conv[1](x) # scale grad - x = encoder_embed.conv[2](x) # balancer - x = encoder_embed.conv[3](x) # swooshR - - x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2 - assert x.shape == ( - N, - layer2_channels, - ((T - 2) - 3) // 2 + 1, - (num_features - 3) // 2 + 1, - ) - # (2, 32, 98, 39) - - x = encoder_embed.conv[5](x) # balancer - x = encoder_embed.conv[6](x) # swooshR - - # conv2d: - # in 32, out 128, kernel 3, stride (1, 2) - x = encoder_embed.conv[7](x) - assert x.shape == ( - N, - layer3_channels, - (((T - 2) - 3) // 2 + 1) - 2, - (((num_features - 3) // 2 + 1) - 3) // 2 + 1, - ) - # (2, 128, 96, 19) - - x = encoder_embed.conv[8](x) # balancer - x = encoder_embed.conv[9](x) # swooshR - - # (((T - 2) - 3) // 2 + 1) - 2 - # = (T - 2) - 3) // 2 + 1 - 2 - # = ((T - 2) - 3) // 2 - 1 - # = (T - 2 - 3) // 2 - 1 - # = (T - 5) // 2 - 1 - # = (T - 7) // 2 - assert x.shape[2] == (x_copy.shape[1] - 7) // 2 - - # (((num_features - 3) // 2 + 1) - 3) // 2 + 1, - # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1, - # = ((num_features - 3) // 2 - 2) // 2 + 1, - # = (num_features - 3 - 4) // 2 // 2 + 1, - # = (num_features - 7) // 2 // 2 + 1, - # = (num_features - 7) // 4 + 1, - # = (num_features - 3) // 4 - assert x.shape[3] == (x_copy.shape[2] - 3) // 4 - - assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) - - # Input shape to convnext is - # - # (N, layer3_channels, (T-7)//2, (num_features - 3)//4) - - # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels - # kernel_size 7, padding 3 - x = encoder_embed.convnext.depthwise_conv(x) - assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) - - # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1 - x = encoder_embed.convnext.pointwise_conv1(x) - assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4) - - x = encoder_embed.convnext.hidden_balancer(x) # balancer - x = encoder_embed.convnext.activation(x) # swooshL - - # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1 - x = encoder_embed.convnext.pointwise_conv2(x) - assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) - - # bypass and layer drop, omitted here. - x = encoder_embed.convnext.out_balancer(x) - - # Note: the input and output shape of ConvNeXt are the same - - x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1) - assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4)) - - x = encoder_embed.out(x) - assert x.shape == (N, (T - 7) // 2, out_channels) - - x = encoder_embed.out_whiten(x) - x = encoder_embed.out_norm(x) - # final layer is dropout - - # test streaming forward - - subsampling_factor = 2 - cached_left_padding = encoder_embed.get_init_states(batch_size=N) - depthwise_conv_kernel_size = 7 - pad_size = (depthwise_conv_kernel_size - 1) // 2 - - assert cached_left_padding.shape == ( - N, - layer3_channels, - pad_size, - (num_features - 3) // 4, - ) - - chunk_size = 16 - right_padding = pad_size * subsampling_factor - T = chunk_size * subsampling_factor + 7 + right_padding - x = torch.rand(N, T, num_features) - x_lens = torch.tensor([T] * N) - y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward( - x, x_lens, cached_left_padding - ) - - assert y.shape == (N, chunk_size, out_channels), y.shape - assert next_cached_left_padding.shape == cached_left_padding.shape - - assert y.shape[1] == y_lens[0] == y_lens[1] - - -def main(): - test_conv2d_subsampling() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index a6551a587c..3d21e76be9 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -86,7 +86,7 @@ from variable_combined_scheduler import VariableCombinedLRScheduler try: from variable_combined_scheduler import InterpCosineLRScheduler - LRSchedulerType = VariableCombinedLRSchedule + LRSchedulerType = VariableCombinedLRScheduler except: pass @@ -807,6 +807,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 decoder = AttentionDecoderModel( vocab_size=params.vocab_size, decoder_dim=lookup(params, "attention_decoder_dim"), diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py index adad0fc5bc..0fcbe1381e 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_modules.py +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -154,7 +154,7 @@ def forward( @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x, offset, scale, ballast_rms, ballast_frames = ctx.saved_tensors diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py index 648be4b1e0..bff536f90b 100644 --- a/egs/librispeech/ASR/zipformer/attention_decoder.py +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -23,7 +23,7 @@ import torch import torch.nn as nn from label_smoothing import LabelSmoothingLoss -from zapformer_utils import penalize_abs_values_gt +from scaling import penalize_abs_values_gt from icefall.utils import add_eos, add_sos, make_pad_mask diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 4e84a76082..bdf3e02dcc 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -361,9 +361,8 @@ def __init__(self, opts: Optional[TensorDiagnosticOptions] = None): self.diagnostics = dict() def __getitem__(self, name: str): - T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic if name not in self.diagnostics: - self.diagnostics[name] = T(self.opts, name) + self.diagnostics[name] = TensorDiagnostic(self.opts, name) return self.diagnostics[name] def print_diagnostics(self) -> dict: From cae74c4c8e375ff88a3be9384db615b2c33ac9b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 May 2026 21:57:03 +0800 Subject: [PATCH 1166/1191] Further fix --- egs/librispeech/ASR/zapformer/model.py | 2 +- egs/librispeech/ASR/zapformer/streaming_beam_search.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 61ac25e3d3..42f7be3a82 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -322,7 +322,7 @@ def forward( am_scale: float = 0.0, lm_scale: float = 0.0, aux_loss_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: diff --git a/egs/librispeech/ASR/zapformer/streaming_beam_search.py b/egs/librispeech/ASR/zapformer/streaming_beam_search.py index 3c8565b330..d5a475627a 100644 --- a/egs/librispeech/ASR/zapformer/streaming_beam_search.py +++ b/egs/librispeech/ASR/zapformer/streaming_beam_search.py @@ -48,7 +48,7 @@ def greedy_search( blank_id = model.decoder.blank_id context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device T = encoder_out.size(1) decoder_input = torch.tensor( From ba4720118f559861b5ede201fd1ead06ac1e44a4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 May 2026 11:18:33 +0800 Subject: [PATCH 1167/1191] Fix loading state dict dtype issue --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index cf4cdc8851..048ef3b9a0 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -346,6 +346,10 @@ def muon_core_step(group, state, grad): momentum_buffer = state["momentum_buffer"] second_momentum_buffer = state["second_momentum_buffer"] + if momentum_buffer.dtype == torch.float: # Error due to loading state dict; TODO put this in load_state_dict() + momentum_buffer = momentum_buffer.to(COMPUTE_DTYPE) + state["momentum_buffer"] = momentum_buffer + def t(x): return torch.tensor(x, device=grad.device, dtype=COMPUTE_DTYPE) From 909cd2d4350a9ee45e60133915ce38a8d6ebe123 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 May 2026 11:24:21 +0800 Subject: [PATCH 1168/1191] Reduce beta1 in BatchedRubik[muon-core] from .99 to .98. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index a6551a587c..d3708aef4d 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), lr=params.base_lr, - beta1=0.99, + beta1=0.98, ) if True: From f2a811e67dc55ef8599fe76de470c8915a35b6db Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 May 2026 16:46:38 +0800 Subject: [PATCH 1169/1191] Changed muon-core update to have symmetric row-or-col normalization both before and after the orthogonalization. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 048ef3b9a0..f3eabdb8e5 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -292,6 +292,10 @@ def muon_step_fused( momentum_buffer.lerp_(stacked_grads, 1 - momentum) g = stacked_grads.lerp_(momentum_buffer, momentum) + # apply the same normalization both before and after + # the core muon step, the symmetry ensures it is a descent direction. + g = g / (second_momentum_buffer.sqrt() + eps).to(g.dtype) + # Polar express # Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range) X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g @@ -338,9 +342,9 @@ def muon_core_step(group, state, grad): assert step < 2 state["momentum_buffer"] = torch.zeros(batch_size, rows, cols, device=grad.device, dtype=COMPUTE_DTYPE) if rows > cols: - state["second_momentum_buffer"] = torch.zeros(batch_size, rows, 1, device=grad.device, dtype=torch.float) + state["second_momentum_buffer"] = torch.ones(batch_size, rows, 1, device=grad.device, dtype=torch.float) else: - state["second_momentum_buffer"] = torch.zeros(batch_size, 1, cols, device=grad.device, dtype=torch.float) + state["second_momentum_buffer"] = torch.ones(batch_size, 1, cols, device=grad.device, dtype=torch.float) momentum_buffer = state["momentum_buffer"] From 2ef9d39d63309a50243478e6a903fb563fc8d1a3 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 25 May 2026 16:56:11 +0800 Subject: [PATCH 1170/1191] copy export-onnx.py from zipformer --- egs/librispeech/ASR/zapformer/export-onnx.py | 648 +++++++++++++++++++ 1 file changed, 648 insertions(+) create mode 100755 egs/librispeech/ASR/zapformer/export-onnx.py diff --git a/egs/librispeech/ASR/zapformer/export-onnx.py b/egs/librispeech/ASR/zapformer/export-onnx.py new file mode 100755 index 0000000000..03c7d6f820 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export-onnx.py @@ -0,0 +1,648 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" \ + --fp16 True +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + if params.fp16: + logging.info("Generate fp16 models") + + encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_filename_fp16) + + decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" + export_onnx_fp16(decoder_filename, decoder_filename_fp16) + + joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" + export_onnx_fp16(joiner_filename, joiner_filename_fp16) + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() From 27aea4e6c114e7dfc42a517d3871b470fd6f75b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 May 2026 16:57:40 +0800 Subject: [PATCH 1171/1191] Revert beta1 from 0.98 to 0.99. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index d3708aef4d..a6551a587c 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -1435,7 +1435,7 @@ def run(rank, world_size, args): optimizer = BatchedRubik( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), lr=params.base_lr, - beta1=0.98, + beta1=0.99, ) if True: From c3fc7d6c75e95f463cd43db3076989c109921b30 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 28 May 2026 10:20:46 +0800 Subject: [PATCH 1172/1191] fix to onnx exporting, not working yet --- egs/librispeech/ASR/zapformer/export-onnx.py | 81 +++++++++---------- egs/librispeech/ASR/zapformer/model.py | 2 - .../ASR/zapformer/scaling_converter.py | 53 ++++++++++++ egs/librispeech/ASR/zapformer/subsampling.py | 1 + egs/librispeech/ASR/zapformer/zapformer.py | 18 +++-- .../ASR/zapformer/zapformer_modules.py | 29 ++++--- 6 files changed, 122 insertions(+), 62 deletions(-) create mode 100644 egs/librispeech/ASR/zapformer/scaling_converter.py diff --git a/egs/librispeech/ASR/zapformer/export-onnx.py b/egs/librispeech/ASR/zapformer/export-onnx.py index 03c7d6f820..94a00aa9c3 100755 --- a/egs/librispeech/ASR/zapformer/export-onnx.py +++ b/egs/librispeech/ASR/zapformer/export-onnx.py @@ -6,28 +6,11 @@ """ This script exports a transducer model from PyTorch to ONNX. -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model +Usage: cd egs/librispeech/ASR -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ +./zapformer/export-onnx.py \ --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ @@ -73,7 +56,7 @@ from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 +from zapformer import Zapformer from icefall.checkpoint import ( average_checkpoints, @@ -131,7 +114,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="zapformer/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, @@ -191,15 +174,15 @@ def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" + """A wrapper for Zapformer and the encoder_proj from the joiner""" def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + self, encoder: Zapformer, encoder_embed: nn.Module, encoder_proj: nn.Linear ): """ Args: encoder: - A Zipformer encoder. + A Zapformer encoder. encoder_proj: The projection layer for encoder from the joiner. """ @@ -213,7 +196,7 @@ def forward( x: torch.Tensor, x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Zipformer.forward + """Please see the help information of Zapformer.forward Args: x: @@ -313,29 +296,43 @@ def export_encoder_model_onnx( x = torch.zeros(1, 100, 80, dtype=torch.float32) x_lens = torch.tensor([100], dtype=torch.int64) + # Pre-compute angular frequency bases so tracing uses cached values + # instead of recomputing with varying constants per layer. + encoder_model.encoder.warmup_angular_freq_bases( + seq_len=100, left_context_len=0, device=x.device + ) + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) + import traceback + + try: + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=True, + enable_onnx_checker=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + except Exception as e: + logging.error(f"Failed to export the encoder model to ONNX: {e}") + logging.error(traceback.format_exc()) + raise e meta_data = { - "model_type": "zipformer2", + "model_type": "zapformer", "version": "1", "model_author": "k2-fsa", - "comment": "non-streaming zipformer2", + "comment": "non-streaming zapformer", } logging.info(f"meta_data: {meta_data}") diff --git a/egs/librispeech/ASR/zapformer/model.py b/egs/librispeech/ASR/zapformer/model.py index 42f7be3a82..f64528b8ef 100755 --- a/egs/librispeech/ASR/zapformer/model.py +++ b/egs/librispeech/ASR/zapformer/model.py @@ -94,8 +94,6 @@ def __init__( assert hasattr(decoder, "blank_id") assert joiner is not None - - self.decoder = decoder self.joiner = joiner diff --git a/egs/librispeech/ASR/zapformer/scaling_converter.py b/egs/librispeech/ASR/zapformer/scaling_converter.py new file mode 100644 index 0000000000..aaa3a8c1f9 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/scaling_converter.py @@ -0,0 +1,53 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file provides a convert_scaled_to_non_scaled() function for zapformer. + +Unlike zipformer, zapformer's training-only modules (ScaleLimiter, +CorrelationLimiter, ActivationAndLinear, etc.) already handle ONNX tracing +internally via torch.jit.is_tracing() checks, so no module replacement is +needed at export time. This function is provided for API compatibility. +""" + +import torch.nn as nn + + +def convert_scaled_to_non_scaled( + model: nn.Module, + inplace: bool = False, + is_pnnx: bool = False, + is_onnx: bool = False, +): + """ + Args: + model: + The model to be converted. + inplace: + If True, the input model is modified inplace. + If False, the input model is copied and we modify the copied version. + is_pnnx: + True if we are going to export the model for PNNX. + is_onnx: + True if we are going to export the model for ONNX. + Return: + Return the model unchanged. + + Note: zapformer modules already return identity/zero during torch.jit + tracing, so no conversion is necessary. + """ + return model diff --git a/egs/librispeech/ASR/zapformer/subsampling.py b/egs/librispeech/ASR/zapformer/subsampling.py index 45c97d468f..c4cb90ea61 100644 --- a/egs/librispeech/ASR/zapformer/subsampling.py +++ b/egs/librispeech/ASR/zapformer/subsampling.py @@ -62,6 +62,7 @@ def __init__( if not causal: padding = (kernel_size[0] // 2, kernel_size[1] // 2) + self.left_pad = 0 else: padding = (0, kernel_size[1] // 2) self.left_pad = kernel_size[0] - 1 diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index 5e4587ed84..bde0b1ec88 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -344,6 +344,12 @@ def compute_projection_overlap(self, verbose: bool = False): logging.info(f"overlap[{i}, {j}] = {cosine}") return tot_loss + def warmup_angular_freq_bases(self, seq_len: int, left_context_len: int, device: torch.device): + """Pre-compute angular frequency bases for all encoder layers. + Call this before torch.jit.trace to avoid tracer issues.""" + for module in self.encoders: + for layer in module.layers: + layer.self_attn.rel_pos.angular_freq_basis(seq_len, left_context_len, device) def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -1378,9 +1384,6 @@ def backward( return attn_scores_grad + attn_scores.grad, None, None, None, None - - - class FeedforwardModule(nn.Module): """Feedforward module in Zapformer model.""" @@ -1553,7 +1556,13 @@ def forward(self, seq_len: int, left_context_len: int, device: torch.device) -> end = start + 2 * seq_len + left_context_len - 1 return self._cached_basis[start:end] - t = torch.arange(-(seq_len + left_context_len - 1), seq_len, device=device) + if torch.jit.is_tracing(): + raise RuntimeError( + "AngularFreqBasis: cache miss during tracing. " + "Call warmup_angular_freq_bases() before tracing." + ) + + t = torch.arange(-(seq_len + left_context_len - 1), seq_len, dtype=torch.double, device=device) basis = compute_angular_freq_basis_triangular(self.freqs, t, scale=False) # basis: (2 * seq_len + left_context_len - 1, num_freqs, 2) basis = basis.permute(0, 2, 1) @@ -1969,7 +1978,6 @@ def forward(self, - class ConvolutionModule(nn.Module): """ConvolutionModule in Zapformer model. diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py index 0fcbe1381e..2558087562 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_modules.py +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -21,7 +21,6 @@ import random from typing import Optional, Tuple, Union, Any -import k2 import torch import torch.nn as nn from torch import Tensor @@ -29,8 +28,6 @@ from zapformer_utils import limit_param_value - - def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) diff = torch.abs(x - y) @@ -62,7 +59,6 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: return torch.logaddexp(x, y) - # all arg tensors except x are scalars. def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): stats = (x ** 2).mean(dim=2, keepdim=True) @@ -149,7 +145,6 @@ def forward( ballast_frames: Tensor, ) -> Tensor: ctx.save_for_backward(x, offset, scale, ballast_rms, ballast_frames) - return _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames) @@ -157,7 +152,6 @@ def forward( def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x, offset, scale, ballast_rms, ballast_frames = ctx.saved_tensors - with torch.amp.autocast('cuda', enabled=False): x = x.to(torch.float32).detach().requires_grad_() offset = offset.to(torch.float32).detach().requires_grad_() @@ -179,6 +173,7 @@ def c(x): return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad) + class SequenceNormFunction(torch.autograd.Function): @staticmethod def forward( @@ -190,7 +185,6 @@ def forward( ) -> Tensor: ctx.save_for_backward(x, offset, scale) ctx.mask = mask - return _sequence_norm(x, offset, scale, mask) @@ -336,6 +330,12 @@ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: return ans + @torch.jit.export + def get_init_cache(self, batch_size: int): + """Get initial cache for streaming inference.""" + cached_stats_sum = torch.zeros(batch_size) + cached_len = torch.zeros(batch_size) + return cached_stats_sum, cached_len # assume layout: (time, batch, channel) @@ -727,8 +727,6 @@ def forward(self, x): return _no_op(x) - - def torch_compile(fn, *args, **kwargs): if hasattr(torch, 'compile'): fn = torch.compile(fn, *args, **kwargs, dynamic=True, options={"shape_padding": True, "force_shape_pad": True}) @@ -771,6 +769,8 @@ def __init__(self): def forward(self, x: Tensor) -> Tensor: """Return Swash-L activation, which is the same as SwooshL but with a factor of 4 on the input and 0.25 on the output..""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return swashl(x) return self.func(x) class SwashR(torch.nn.Module): @@ -780,10 +780,11 @@ def __init__(self): def forward(self, x: Tensor) -> Tensor: """Return Swash-R activation, which is the same as SwooshL but with a factor of 4 on the input and 0.25 on the output..""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return swashr(x) return self.func(x) - class ActivationAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd @@ -823,7 +824,6 @@ def backward(ctx, ans_grad: Tensor): return x_deriv, weight_deriv, bias_deriv, None, None - class ActivationAndLinear(torch.nn.Module): """ This merges an activation function followed by a nn.Linear module; @@ -875,7 +875,11 @@ def __init__( def forward(self, x: Tensor): if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): - x = self.forward_func(x) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + func = swashl if self.activation == "SwashL" else swashr + else: + func = self.forward_func + x = func(x) return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationAndLinearFunction.apply( @@ -887,7 +891,6 @@ def forward(self, x: Tensor): ) - def _test_swashl_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 x.requires_grad = True From b638c5ae99f58663683b8000d57e67419f6d9c15 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 28 May 2026 12:02:09 +0800 Subject: [PATCH 1173/1191] export transducer works --- egs/librispeech/ASR/zapformer/export-onnx.py | 5 +- egs/librispeech/ASR/zapformer/export.py | 526 ++++++++++++++++++ egs/librispeech/ASR/zapformer/onnx_check.py | 121 +++- egs/librispeech/ASR/zapformer/zapformer.py | 14 +- .../ASR/zapformer/zapformer_modules.py | 16 +- .../ASR/zapformer/zapformer_utils.py | 9 +- 6 files changed, 669 insertions(+), 22 deletions(-) create mode 100755 egs/librispeech/ASR/zapformer/export.py diff --git a/egs/librispeech/ASR/zapformer/export-onnx.py b/egs/librispeech/ASR/zapformer/export-onnx.py index 94a00aa9c3..f56297cb6f 100755 --- a/egs/librispeech/ASR/zapformer/export-onnx.py +++ b/egs/librispeech/ASR/zapformer/export-onnx.py @@ -265,7 +265,8 @@ def forward( Return a 2-D tensor of shape (N, vocab_size) """ logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) + # see comment in joiner.py for the scale of 2.0 + logit = 2.0 * self.output_linear(torch.tanh(logit)) return logit @@ -302,8 +303,6 @@ def export_encoder_model_onnx( seq_len=100, left_context_len=0, device=x.device ) - encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) - import traceback try: diff --git a/egs/librispeech/ASR/zapformer/export.py b/egs/librispeech/ASR/zapformer/export.py new file mode 100755 index 0000000000..bf7ee65208 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zapformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zapformer/decode.py \ + --exp-dir ./zapformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zapformer/decode.py` and `zapformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zapformer/decode.py \ + --exp-dir ./zapformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zapformer/streaming_decode.py \ + --exp-dir ./zapformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer + +- streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zapformer + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zapformer + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed (streaming)""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors. + states[:-2] are the encoder caches (9 tensors per layer). + states[-2] is the cached left padding for ConvNeXt module. + states[-1] is processed_lens of shape (batch,). + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed = states[-2] + x, x_lens, new_cached_embed = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cache=cached_embed, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_caches = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_caches, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + caches=encoder_caches, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_caches + [ + new_cached_embed, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, + states[i*9:(i+1)*9] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len, cached_attn_wm_sum, + cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs). + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_caches(batch_size, device) + + embed_cache = self.encoder_embed.get_init_cache(batch_size, device) + states.append(embed_cache) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.sos_id = params.eos_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_check.py b/egs/librispeech/ASR/zapformer/onnx_check.py index c248ea6487..f57886c2b3 100755 --- a/egs/librispeech/ASR/zapformer/onnx_check.py +++ b/egs/librispeech/ASR/zapformer/onnx_check.py @@ -79,8 +79,11 @@ import argparse import logging +from typing import Tuple +import onnxruntime as ort + import torch -from onnx_pretrained import OnnxModel +# from onnx_pretrained import OnnxModel def get_parser(): @@ -119,6 +122,122 @@ def get_parser(): return parser +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + def test_encoder( torch_model: torch.jit.ScriptModule, onnx_model: OnnxModel, diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index bde0b1ec88..a726c8b7ec 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -674,8 +674,7 @@ def forward( src_orig = src src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2), - 2. * aux_loss_scale, mask=src_key_padding_mask), - None) + 2. * aux_loss_scale, mask=src_key_padding_mask)) src_pre_ff1 = src @@ -855,7 +854,7 @@ def forward( attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, aux_loss_scale: float = 0.0, - ) -> Tuple[Tensor, Tensor]: + ) -> Tensor: r"""Pass the input through the encoder layers in turn. Args: @@ -870,8 +869,7 @@ def forward( masked position. May be None. Returns: - (out, out_sd), both of the same shape as src, - where out_sd is an alternative version of out for stochastic-depth, that does not see the bypass. + out, of the same shape as src. """ src_orig_fulldim = src @@ -1178,7 +1176,7 @@ def forward( g = vg[..., N:] if self.training: # don't let the sigmoid values get too extreme, limit to -2..2. - g = penalize_abs_values_gt(g, 2, penalty=0.02*aux_loss_scale) + g = penalize_abs_values_gt(g, 2.0, penalty=0.02*aux_loss_scale) g_in, g_out = g.chunk(2, dim=-1) v = v * self.sigmoid_in(g_in) @@ -1535,7 +1533,7 @@ def __init__(self, num_freqs: int, low_freq_factor: float = 0.001): freqs[0] = 0.0 # in case of roundoff self.register_buffer('freqs', freqs, persistent=False) - self._cached_basis: Optional[Tensor] = None + self._cached_basis: Tensor = torch.empty(0) self._cached_seq_len: int = -1 self._cached_left_context_len: int = -1 @@ -1549,7 +1547,7 @@ def forward(self, seq_len: int, left_context_len: int, device: torch.device) -> """ S = self._cached_seq_len L = self._cached_left_context_len - if (self._cached_basis is not None + if (self._cached_basis.numel() > 0 and seq_len <= S and seq_len + left_context_len <= S + L): start = S + L - seq_len - left_context_len diff --git a/egs/librispeech/ASR/zapformer/zapformer_modules.py b/egs/librispeech/ASR/zapformer/zapformer_modules.py index 2558087562..aaf31378b8 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_modules.py +++ b/egs/librispeech/ASR/zapformer/zapformer_modules.py @@ -65,12 +65,12 @@ def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tens T = x.shape[0] # time if mask is None: stats = stats.sum(dim=0) - lengths = T + lengths = torch.tensor(T, dtype=stats.dtype, device=stats.device) else: - mask = (~mask).to(torch.float).t().unsqueeze(-1) - stats = stats * mask + mask_f = (~mask).to(torch.float).t().unsqueeze(-1) + stats = stats * mask_f stats = stats.sum(dim=0) - lengths = mask.sum(dim=0) + lengths = mask_f.sum(dim=0) scales = (lengths / stats).sqrt() assert scales.shape == (x.shape[1], 1) @@ -876,10 +876,12 @@ def __init__( def forward(self, x: Tensor): if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): if torch.jit.is_scripting() or torch.jit.is_tracing(): - func = swashl if self.activation == "SwashL" else swashr + if self.activation == "SwashL": + x = swashl(x) + else: + x = swashr(x) else: - func = self.forward_func - x = func(x) + x = self.forward_func(x) return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationAndLinearFunction.apply( diff --git a/egs/librispeech/ASR/zapformer/zapformer_utils.py b/egs/librispeech/ASR/zapformer/zapformer_utils.py index 0470b74690..e7db94b884 100644 --- a/egs/librispeech/ASR/zapformer/zapformer_utils.py +++ b/egs/librispeech/ASR/zapformer/zapformer_utils.py @@ -66,7 +66,7 @@ def softmax(x: Tensor, dim: int): def penalize_abs_values_gt( - x: Tensor, limit: float, penalty: float, name: str = None + x: Tensor, limit: float, penalty: float, name: str = "" ) -> Tensor: """ Returns x unmodified, but in backprop will put a penalty for the excess of @@ -118,8 +118,9 @@ def backward(ctx, ans_grad: Tensor): ) -def with_loss(x, y, name=None): - # returns x but adds y.sum() to the loss function. +def with_loss(x: Tensor, y: Tensor, name: str = "") -> Tensor: + if torch.jit.is_scripting(): + return x return WithLoss.apply(x, y, name) @@ -152,6 +153,8 @@ def limit_param_value( # You apply this to (typically) an nn.Parameter during training to ensure that its # (elements mostly) stays within a supplied range. This is done by modifying the # gradients in backprop. + if torch.jit.is_scripting(): + return x if training: return LimitParamValue.apply(x, min, max) else: From 6d0080c4cf9a14225890713093c7a7352fffc011 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 28 May 2026 13:41:14 +0800 Subject: [PATCH 1174/1191] fix onnx inference --- egs/librispeech/ASR/zapformer/export-onnx.py | 15 +- .../ASR/zapformer/jit_pretrained.py | 281 ++++++++++++++++++ .../ASR/zapformer/onnx_pretrained.py | 22 +- 3 files changed, 304 insertions(+), 14 deletions(-) create mode 100755 egs/librispeech/ASR/zapformer/jit_pretrained.py diff --git a/egs/librispeech/ASR/zapformer/export-onnx.py b/egs/librispeech/ASR/zapformer/export-onnx.py index f56297cb6f..db88bc1c0a 100755 --- a/egs/librispeech/ASR/zapformer/export-onnx.py +++ b/egs/librispeech/ASR/zapformer/export-onnx.py @@ -294,13 +294,22 @@ def export_encoder_model_onnx( opset_version: The opset version to use. """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) + # Use a large dummy input (3000 fbank frames ≈ 30s audio) so that the + # relative position basis baked into the ONNX graph is large enough for + # any input encountered at inference time. The basis becomes a constant + # in the ONNX graph, and the GatherElements indices at runtime must not + # exceed its size. + T_max = 3000 + x = torch.zeros(1, T_max, 80, dtype=torch.float32) + x_lens = torch.tensor([T_max], dtype=torch.int64) # Pre-compute angular frequency bases so tracing uses cached values # instead of recomputing with varying constants per layer. + # After Conv2dSubsampling, T_max → ~(T_max-7)//2 ≈ 1496 frames. + # Each encoder stack further downsamples, so the max seq_len seen by + # any stack is ~1496. We use 1500 to be safe. encoder_model.encoder.warmup_angular_freq_bases( - seq_len=100, left_context_len=0, device=x.device + seq_len=1500, left_context_len=0, device=x.device ) import traceback diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained.py b/egs/librispeech/ASR/zapformer/jit_pretrained.py new file mode 100755 index 0000000000..daa795df50 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/jit_pretrained.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained.py \ + --nn-model-filename ./zipformer/exp/cpu_jit.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = model.decoder.blank_id + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + logging.info("Constructing Fbank computer") + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = [] + for w in waves: + feat = torchaudio.compliance.kaldi.fbank( + w.unsqueeze(0), + num_mel_bins=80, + sample_frequency=16000, + dither=0, + snip_edges=False, + high_freq=-400, + ) # (num_frames, 80) + features.append(feat.to(device)) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + features=features, + feature_lengths=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + s = "\n" + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained.py b/egs/librispeech/ASR/zapformer/onnx_pretrained.py index cbbaa27c09..39b5a70fd2 100755 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained.py @@ -71,7 +71,6 @@ from typing import List, Tuple import k2 -import kaldifeat import onnxruntime as ort import torch import torchaudio @@ -363,15 +362,6 @@ def main(): ) logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) logging.info(f"Reading sound files: {args.sound_files}") waves = read_sound_files( @@ -380,7 +370,17 @@ def main(): ) logging.info("Decoding started") - features = fbank(waves) + features = [] + for w in waves: + feat = torchaudio.compliance.kaldi.fbank( + w.unsqueeze(0), + num_mel_bins=80, + sample_frequency=args.sample_rate, + dither=0, + snip_edges=False, + high_freq=-400, + ) # (num_frames, 80) + features.append(feat) feature_lengths = [f.size(0) for f in features] features = pad_sequence( From d284af8a08386b88439f8ecbd5efaaef32cce9ba Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 28 May 2026 14:53:39 +0800 Subject: [PATCH 1175/1191] minor fixes --- egs/librispeech/ASR/zapformer/export-onnx.py | 50 ++++---- egs/librispeech/ASR/zapformer/export.py | 4 +- .../ASR/zapformer/jit_pretrained.py | 4 +- egs/librispeech/ASR/zapformer/onnx_check.py | 120 +----------------- egs/librispeech/ASR/zapformer/pretrained.py | 24 ++-- .../ASR/zapformer/scaling_converter.py | 53 -------- 6 files changed, 39 insertions(+), 216 deletions(-) delete mode 100644 egs/librispeech/ASR/zapformer/scaling_converter.py diff --git a/egs/librispeech/ASR/zapformer/export-onnx.py b/egs/librispeech/ASR/zapformer/export-onnx.py index db88bc1c0a..3823c66e20 100755 --- a/egs/librispeech/ASR/zapformer/export-onnx.py +++ b/egs/librispeech/ASR/zapformer/export-onnx.py @@ -1,7 +1,21 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) +# Copyright 2021-2026 Xiaomi Corporation (Author: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This script exports a transducer model from PyTorch to ONNX. @@ -12,32 +26,15 @@ ./zapformer/export-onnx.py \ --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal False \ - --chunk-size "16,32,64,-1" \ - --left-context-frames "64,128,256,-1" \ + --epoch 13 \ + --avg 2 \ + --exp-dir zapformer/exp \ --fp16 True It will generate the following 3 files inside $repo/exp: - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx + - encoder-epoch-13-avg-2.onnx + - decoder-epoch-13-avg-2.onnx + - joiner-epoch-13-avg-2.onnx See ./onnx_pretrained.py and ./onnx_check.py for how to use the exported ONNX models. @@ -54,7 +51,6 @@ import torch.nn as nn from decoder import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params from zapformer import Zapformer @@ -545,8 +541,6 @@ def main(): model.to("cpu") model.eval() - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - encoder = OnnxEncoder( encoder=model.encoder, encoder_embed=model.encoder_embed, diff --git a/egs/librispeech/ASR/zapformer/export.py b/egs/librispeech/ASR/zapformer/export.py index bf7ee65208..0e0bca0f42 100755 --- a/egs/librispeech/ASR/zapformer/export.py +++ b/egs/librispeech/ASR/zapformer/export.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Copyright 2021-2026 Xiaomi Corporation (Author: Fangjun Kuang, # Zengwei Yao, # Wei Kang) # @@ -398,8 +398,6 @@ def main(): params.update(vars(args)) device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) logging.info(f"device: {device}") diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained.py b/egs/librispeech/ASR/zapformer/jit_pretrained.py index daa795df50..201204b7a4 100755 --- a/egs/librispeech/ASR/zapformer/jit_pretrained.py +++ b/egs/librispeech/ASR/zapformer/jit_pretrained.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# Copyright 2021-2026 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/librispeech/ASR/zapformer/onnx_check.py b/egs/librispeech/ASR/zapformer/onnx_check.py index f57886c2b3..daca7d81bd 100755 --- a/egs/librispeech/ASR/zapformer/onnx_check.py +++ b/egs/librispeech/ASR/zapformer/onnx_check.py @@ -79,11 +79,9 @@ import argparse import logging -from typing import Tuple -import onnxruntime as ort import torch -# from onnx_pretrained import OnnxModel +from onnx_pretrained import OnnxModel def get_parser(): @@ -122,122 +120,6 @@ def get_parser(): return parser -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - def test_encoder( torch_model: torch.jit.ScriptModule, onnx_model: OnnxModel, diff --git a/egs/librispeech/ASR/zapformer/pretrained.py b/egs/librispeech/ASR/zapformer/pretrained.py index 3dc98085ec..9e859332f8 100755 --- a/egs/librispeech/ASR/zapformer/pretrained.py +++ b/egs/librispeech/ASR/zapformer/pretrained.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# Copyright 2021-2026 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao, Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -114,7 +114,6 @@ from typing import List import k2 -import kaldifeat import torch import torchaudio from beam_search import ( @@ -295,15 +294,6 @@ def main(): model.eval() logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) logging.info(f"Reading sound files: {params.sound_files}") waves = read_sound_files( @@ -312,7 +302,17 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") - features = fbank(waves) + features = [] + for w in waves: + feat = torchaudio.compliance.kaldi.fbank( + w.unsqueeze(0), + num_mel_bins=params.feature_dim, + sample_frequency=params.sample_rate, + dither=0, + snip_edges=False, + high_freq=-400, + ) # (num_frames, feature_dim) + features.append(feat.to(device)) feature_lengths = [f.size(0) for f in features] features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) diff --git a/egs/librispeech/ASR/zapformer/scaling_converter.py b/egs/librispeech/ASR/zapformer/scaling_converter.py deleted file mode 100644 index aaa3a8c1f9..0000000000 --- a/egs/librispeech/ASR/zapformer/scaling_converter.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file provides a convert_scaled_to_non_scaled() function for zapformer. - -Unlike zipformer, zapformer's training-only modules (ScaleLimiter, -CorrelationLimiter, ActivationAndLinear, etc.) already handle ONNX tracing -internally via torch.jit.is_tracing() checks, so no module replacement is -needed at export time. This function is provided for API compatibility. -""" - -import torch.nn as nn - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_pnnx: bool = False, - is_onnx: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_pnnx: - True if we are going to export the model for PNNX. - is_onnx: - True if we are going to export the model for ONNX. - Return: - Return the model unchanged. - - Note: zapformer modules already return identity/zero during torch.jit - tracing, so no conversion is necessary. - """ - return model From 9cb74d0565c3e152c976bce8cfb8d7c173da8efd Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 28 May 2026 15:07:12 +0800 Subject: [PATCH 1176/1191] minor fix --- egs/librispeech/ASR/zapformer/export.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/export.py b/egs/librispeech/ASR/zapformer/export.py index 0e0bca0f42..b1bda25d10 100755 --- a/egs/librispeech/ASR/zapformer/export.py +++ b/egs/librispeech/ASR/zapformer/export.py @@ -165,7 +165,6 @@ import k2 import torch -from scaling_converter import convert_scaled_to_non_scaled from torch import Tensor, nn from train import add_model_arguments, get_model, get_params @@ -487,7 +486,6 @@ def main(): model.eval() if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not From de5a49f485ad1ccf1fd0a6343fb89148689b1ec5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 30 May 2026 21:45:42 +0800 Subject: [PATCH 1177/1191] Fix some dtypes in optimizer. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index f3eabdb8e5..fd4619ca51 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -433,8 +433,8 @@ def adam_step(group, state, grad): exp_avg_sq = state["exp_avg_sq"] except KeyError as e: assert step < 2 - exp_avg = torch.zeros(*grad.shape, device=grad.device, dtype=COMPUTE_DTYPE) - exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=COMPUTE_DTYPE) + exp_avg = torch.zeros_like(grad) + exp_avg_sq = torch.zeros_like(grad) state["exp_avg"] = exp_avg state["exp_avg_sq"] = exp_avg_sq From 3d67a58bfba5c31e0f854220a133909360787695 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 30 May 2026 21:48:51 +0800 Subject: [PATCH 1178/1191] Update the results. --- egs/librispeech/ASR/RESULTS.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ded7065b38..d0c65377c2 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -20,9 +20,9 @@ copies of the data.) | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| -| greedy_search | 1.83 | 3.75 | --epoch 13 --avg 3 | - +| greedy_search | 1.81 | 3.73 | --epoch 13 --avg 3 | +Note on other results: dev-clean=1.73,dev-other,3.55, giga test=16.69 giga dev=1.733. (i.e. on the model trained with Libri only). ### zipformer (zipformer + pruned-transducer w/ CR-CTC) From ee435926a316afcd3b95236f3f00ad89a0720ccf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 30 May 2026 21:51:31 +0800 Subject: [PATCH 1179/1191] Set base-lr to 0.02. --- egs/librispeech/ASR/zapformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py index 3d21e76be9..642285ddf6 100755 --- a/egs/librispeech/ASR/zapformer/train.py +++ b/egs/librispeech/ASR/zapformer/train.py @@ -456,7 +456,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.00065, help="The base learning rate." + "--base-lr", type=float, default=0.02, help="The base learning rate." ) parser.add_argument( From 6c2e9b65233e1217afd695aff15caec160826a89 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 30 May 2026 22:01:52 +0800 Subject: [PATCH 1180/1191] Remove unnecessary state_dict/load_state_dict members. --- .../ASR/zapformer/alternating_spec_augment.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py index 264d72f0b4..e9e2faa83d 100644 --- a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py @@ -261,19 +261,6 @@ def _sample_mask_starts_and_ends(self, batch_size, seq_len, num_masks, max_mask_ return mask_starts, mask_ends - def state_dict(self, **kwargs) -> Dict[str, Any]: - state = { } - for name in ["max_feature_mask_fraction", "num_feature_masks", - "max_frame_mask_fraction", "max_frame_mask_size", "p"]: - state[name] = getattr(self, name) - return state - - - def load_state_dict(self, state_dict: Dict[str, Any]): - for name in ["max_feature_mask_fraction", "num_feature_masks", - "max_frame_mask_fraction", "max_frame_mask_size", "p"]: - if name in state_dict: - setattr(self, name, state_dict[name]) def time_warp_impl(features: torch.Tensor, factor: int, From 881a8edd1733ffcbf3b76a4ea686a1f90d74cd4e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 30 May 2026 22:39:24 +0800 Subject: [PATCH 1181/1191] Fix issue in matrix_shape() pointed out by AI on https://github.com/k2-fsa/icefall/pull/2082 --- egs/librispeech/ASR/zapformer/batched_rubik.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index fd4619ca51..6b21b51899 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -201,8 +201,8 @@ def matrix_shape(shape): cumprod = [ ] numel = 1 for k in shape: - cumprod.append(k) numel = numel * k + cumprod.append(numel) diffs = [ abs(k - numel // k) for k in cumprod ] min_diff = min(diffs) for i in range(len(shape)): From 3e785921b46844c5d6d4de48ac3cd632aeade1bd Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 1 Jun 2026 11:15:22 +0800 Subject: [PATCH 1182/1191] fix streaming jit export --- egs/librispeech/ASR/zapformer/zapformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/zapformer.py b/egs/librispeech/ASR/zapformer/zapformer.py index a726c8b7ec..f96dbcf2c2 100644 --- a/egs/librispeech/ASR/zapformer/zapformer.py +++ b/egs/librispeech/ASR/zapformer/zapformer.py @@ -431,9 +431,10 @@ def streaming_forward( left_context_frames = src_key_padding_mask.shape[1] - orig_seq_len assert left_context_frames == self.left_context_frames[0] if pad > 0: + padded_mask = pad_mask(src_key_padding_mask[:, left_context_frames:], x.shape[0]) + assert padded_mask is not None src_key_padding_mask = torch.cat( - (src_key_padding_mask[:, :left_context_frames], - pad_mask(src_key_padding_mask[:, left_context_frames:], x.shape[0])), + [src_key_padding_mask[:, :left_context_frames], padded_mask], dim=1, ) @@ -1215,7 +1216,7 @@ def streaming_forward( cached_wm_sum: Tensor, cached_wm_num_frames: Tensor, key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: r""" Args: x_qkp: input of shape (seq_len, batch_size, embed_dim), that is used for the queries, From f96e36e68916fd1e545a7f2bfbf9c1dc66fc8e9c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Jun 2026 13:36:23 +0800 Subject: [PATCH 1183/1191] Fix from master for ctc_loss bug in torch --- egs/librispeech/ASR/zipformer/model.py | 33 +++++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 6ef2508192..c7ee0e5a6d 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -173,11 +173,22 @@ def forward_ctc( # Compute CTC log-prob ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). .cpu() activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) return ctc_loss @@ -200,12 +211,22 @@ def forward_cr_ctc( to be un-padded and concatenated within 1 dimension. """ # Compute CTC loss + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). .cpu() activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) From 916a25075e3587572feedc54f36775aae4f5f41e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Jun 2026 13:51:21 +0800 Subject: [PATCH 1184/1191] Take zipformer/model.py from master. --- egs/librispeech/ASR/zipformer/model.py | 33 +++++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 6ef2508192..c7ee0e5a6d 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -173,11 +173,22 @@ def forward_ctc( # Compute CTC log-prob ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). .cpu() activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) return ctc_loss @@ -200,12 +211,22 @@ def forward_cr_ctc( to be un-padded and concatenated within 1 dimension. """ # Compute CTC loss + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). .cpu() activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) From ae69eeab2993964831b9d79936f8a67afa9d0260 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 1 Jun 2026 17:19:31 +0800 Subject: [PATCH 1185/1191] fix streaming export and pretrained inference --- .../ASR/zapformer/export-onnx-streaming.py | 783 ++++++++++++++++++ .../ASR/zapformer/jit_pretrained_streaming.py | 86 +- .../zapformer/onnx_pretrained-streaming.py | 194 +++-- 3 files changed, 928 insertions(+), 135 deletions(-) create mode 100755 egs/librispeech/ASR/zapformer/export-onnx-streaming.py diff --git a/egs/librispeech/ASR/zapformer/export-onnx-streaming.py b/egs/librispeech/ASR/zapformer/export-onnx-streaming.py new file mode 100755 index 0000000000..1a4e9bed37 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export-onnx-streaming.py @@ -0,0 +1,783 @@ +#!/usr/bin/env python3 +# +# Copyright 2023-2026 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a streaming transducer model from PyTorch to ONNX. + +Usage: + +cd egs/librispeech/ASR + +./zapformer/export-onnx-streaming.py \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 9 \ + --avg 2 \ + --exp-dir zapformer/exp \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 128 + +It will generate the following 3 files inside exp-dir: + + - encoder-epoch-9-avg-2-chunk-32-left-128.onnx + - decoder-epoch-9-avg-2-chunk-32-left-128.onnx + - joiner-epoch-9-avg-2-chunk-32-left-128.onnx +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from train import add_model_arguments, get_model, get_params +from zapformer import Zapformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--dynamic-batch", + type=int, + default=1, + help="1 to support dynamic batch size. 0 to support only batch size == 1", + ) + + parser.add_argument( + "--enable-int8-quantization", + type=int, + default=1, + help="1 to also export int8 onnx models.", + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + + add_model_arguments(parser) + + return parser + + +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zapformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zapformer, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + def forward( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + N = x.size(0) + T = self.chunk_size * 2 + 7 + x_lens = torch.tensor([T] * N, device=x.device) + left_context_len = self.left_context_len + + embed_cache = states[-2] + x, x_lens, new_embed_cache = self.encoder_embed.streaming_forward( + x=x, + x_lens=x_lens, + cache=embed_cache, + ) + assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) + + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) + + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + new_processed_lens = processed_lens + x_lens + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) + encoder_caches = states[:-2] + logging.info(f"len_encoder_caches={len(encoder_caches)}") + ( + encoder_out, + encoder_out_lens, + new_encoder_caches, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + caches=encoder_caches, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + + new_states = new_encoder_caches + [ + new_embed_cache, + new_processed_lens, + ] + + return encoder_out, new_states + + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, + states[i*9:(i+1)*9] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len, cached_attn_wm_sum, + cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs). + states[-1] is processed_lens of shape (batch,). + """ + states = self.encoder.get_init_caches(batch_size, device) + + embed_cache = self.encoder_embed.get_init_cache(batch_size, device) + states.append(embed_cache) + + processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) + states.append(processed_lens) + + return states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + logit = encoder_out + decoder_out + logit = 2.0 * self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, + feature_dim: int = 80, + dynamic_batch: bool = True, +) -> None: + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.chunk_size * 2 + T = decode_chunk_len + 7 + + x = torch.rand(1, T, feature_dim, dtype=torch.float32) + init_state = encoder_model.get_init_states() + logging.info(f"len(init_state): {len(init_state)}") + + # Warm up angular freq bases for tracing + left_context_len = encoder_model.left_context_len + ds_factors = encoder_model.encoder.downsampling_factor + max_seq_len = left_context_len + encoder_model.chunk_size + encoder_model.encoder.warmup_angular_freq_bases( + seq_len=max_seq_len, left_context_len=left_context_len, device=x.device + ) + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + # Count total number of layers across all encoder stacks + total_layers = sum(encoder_model.encoder.num_encoder_layers) + logging.info(f"total encoder layers: {total_layers}") + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 9, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + logging.info(f"{name}.shape: {tensors[0].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_value_{i}" + logging.info(f"{name}.shape: {tensors[1].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv_{i}" + logging.info(f"{name}.shape: {tensors[2].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # cached_norm_stats: (batch_size,) + name = f"cached_norm_stats_{i}" + logging.info(f"{name}.shape: {tensors[3].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # cached_norm_len: (batch_size,) + name = f"cached_norm_len_{i}" + logging.info(f"{name}.shape: {tensors[4].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # cached_attn_wm_sum: (1, batch_size, attn_value_dim) + name = f"cached_attn_wm_sum_{i}" + logging.info(f"{name}.shape: {tensors[5].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # cached_attn_wm_num_frames: (batch_size,) + name = f"cached_attn_wm_num_frames_{i}" + logging.info(f"{name}.shape: {tensors[6].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # cached_conv_wm_sum: (1, batch_size, embed_dim) + name = f"cached_conv_wm_sum_{i}" + logging.info(f"{name}.shape: {tensors[7].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # cached_conv_wm_num_frames: (batch_size,) + name = f"cached_conv_wm_num_frames_{i}" + logging.info(f"{name}.shape: {tensors[8].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + num_encoder_layers = encoder_model.encoder.num_encoder_layers + encoder_dims = encoder_model.encoder.encoder_dim + conv_params = encoder_model.encoder.conv_params + ds = encoder_model.encoder.downsampling_factor + left_context_len_per_stack = [left_context_len // k for k in ds] + query_head_dims = encoder_model.encoder.query_head_dim + value_head_dims = encoder_model.encoder.value_head_dim + num_heads = encoder_model.encoder.num_heads + + meta_data = { + "model_type": "zapformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "streaming zapformer", + "decode_chunk_len": str(decode_chunk_len), + "T": str(T), + "num_encoder_layers": ",".join(map(str, num_encoder_layers)), + "encoder_dims": ",".join(map(str, encoder_dims)), + "conv_params": ",".join(map(str, conv_params)), + "left_context_len": ",".join(map(str, left_context_len_per_stack)), + "query_head_dims": ",".join(map(str, query_head_dims)), + "value_head_dims": ",".join(map(str, value_head_dims)), + "num_heads": ",".join(map(str, num_heads)), + } + + logging.info(f"meta_data: {meta_data}") + + # 9 tensors per layer + for i in range(len(init_state[:-2]) // 9): + build_inputs_outputs(init_state[i * 9 : (i + 1) * 9], i) + + # (batch_size, channels, left_pad, freq) + embed_cache = init_state[-2] + name = "embed_cache" + logging.info(f"{name}.shape: {embed_cache.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size,) + processed_lens = init_state[-1] + name = "processed_lens" + logging.info(f"{name}.shape: {processed_lens.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + logging.info(f"input_names: {input_names}") + logging.info(f"output_names: {output_names}") + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + } + if dynamic_batch + else {}, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, + dynamic_batch: bool = True, +) -> None: + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(1, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + } + if dynamic_batch + else {}, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, + dynamic_batch: bool = True, +) -> None: + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + } + if dynamic_batch + else {}, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + suffix += f"-chunk-{params.chunk_size}" + suffix += f"-left-{params.left_context_frames}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + str(encoder_filename), + opset_version=opset_version, + feature_dim=params.feature_dim, + dynamic_batch=params.dynamic_batch == 1, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + str(decoder_filename), + opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + str(joiner_filename), + opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + if params.fp16: + logging.info("Generate fp16 models") + + encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_filename_fp16) + + decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" + export_onnx_fp16(decoder_filename, decoder_filename_fp16) + + joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" + export_onnx_fp16(joiner_filename, joiner_filename_fp16) + + # Generate int8 quantization models + if params.enable_int8_quantization: + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py index 9d85756b1e..705a964a8d 100755 --- a/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py +++ b/egs/librispeech/ASR/zapformer/jit_pretrained_streaming.py @@ -45,7 +45,6 @@ import k2 import torch import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature def get_parser(): @@ -148,24 +147,26 @@ def greedy_search( return hyp, decoder_out -def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. +def compute_fbank(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Compute fbank features for the entire waveform at once. + Args: + waveform: + A 1-D float32 tensor of audio samples. + sample_rate: + The sample rate of the audio. Returns: - Return a CPU streaming feature extractor. + Return a 2-D tensor of shape (num_frames, feature_dim). """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - return OnlineFbank(opts) + feat = torchaudio.compliance.kaldi.fbank( + waveform.unsqueeze(0), + num_mel_bins=80, + sample_frequency=sample_rate, + dither=0, + snip_edges=False, + high_freq=-400, + ) + return feat @torch.no_grad() @@ -191,9 +192,7 @@ def main(): token_table = k2.SymbolTable.from_file(args.tokens) context_size = decoder.context_size - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor(args.sample_rate) - + logging.info("Computing fbank features") logging.info(f"Reading sound files: {args.sound_file}") wave_samples = read_sound_files( filenames=[args.sound_file], @@ -201,52 +200,39 @@ def main(): )[0] logging.info(wave_samples.shape) + # Compute all fbank features at once + features = compute_fbank(wave_samples, args.sample_rate) + logging.info(f"features shape: {features.shape}") + logging.info("Decoding started") chunk_length = encoder.chunk_size * 2 - T = chunk_length + encoder.pad_length + T = chunk_length + 7 # Conv2dSubsampling pad_length is a fixed constant logging.info(f"chunk_length: {chunk_length}") logging.info(f"T: {T}") states = encoder.get_init_states(device=device) - tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) - - wave_samples = torch.cat([wave_samples, tail_padding]) - - chunk = int(0.25 * args.sample_rate) # 0.2 second + num_frames = features.size(0) num_processed_frames = 0 hyp = None decoder_out = None - start = 0 - while start < wave_samples.numel(): - logging.info(f"{start}/{wave_samples.numel()}") - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - online_fbank.accept_waveform( - sampling_rate=args.sample_rate, - waveform=samples, + while num_processed_frames + T <= num_frames: + frames = features[num_processed_frames : num_processed_frames + T].to(device).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32, device=device) + encoder_out, out_lens, states = encoder( + features=frames, + feature_lengths=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device ) - while online_fbank.num_frames_ready - num_processed_frames >= T: - frames = [] - for i in range(T): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - frames = torch.cat(frames, dim=0).to(device).unsqueeze(0) - x_lens = torch.tensor([T], dtype=torch.int32, device=device) - encoder_out, out_lens, states = encoder( - features=frames, - feature_lengths=x_lens, - states=states, - ) - num_processed_frames += chunk_length - - hyp, decoder_out = greedy_search( - decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device - ) text = "" for i in hyp[context_size:]: diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py index 2d25805842..0e297a5d30 100755 --- a/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained-streaming.py @@ -78,7 +78,6 @@ import onnxruntime as ort import torch import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature def get_parser(): @@ -155,14 +154,14 @@ def init_encoder_states(self, batch_size: int = 1): logging.info(f"encoder_meta={encoder_meta}") model_type = encoder_meta["model_type"] - assert model_type == "zapformer2", model_type + assert model_type == "zapformer", model_type decode_chunk_len = int(encoder_meta["decode_chunk_len"]) T = int(encoder_meta["T"]) num_encoder_layers = encoder_meta["num_encoder_layers"] encoder_dims = encoder_meta["encoder_dims"] - cnn_module_kernels = encoder_meta["cnn_module_kernels"] + conv_params = encoder_meta["conv_params"] left_context_len = encoder_meta["left_context_len"] query_head_dims = encoder_meta["query_head_dims"] value_head_dims = encoder_meta["value_head_dims"] @@ -173,7 +172,7 @@ def to_int_list(s): num_encoder_layers = to_int_list(num_encoder_layers) encoder_dims = to_int_list(encoder_dims) - cnn_module_kernels = to_int_list(cnn_module_kernels) + conv_params = to_int_list(conv_params) left_context_len = to_int_list(left_context_len) query_head_dims = to_int_list(query_head_dims) value_head_dims = to_int_list(value_head_dims) @@ -183,7 +182,7 @@ def to_int_list(s): logging.info(f"T: {T}") logging.info(f"num_encoder_layers: {num_encoder_layers}") logging.info(f"encoder_dims: {encoder_dims}") - logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"conv_params: {conv_params}") logging.info(f"left_context_len: {left_context_len}") logging.info(f"query_head_dims: {query_head_dims}") logging.info(f"value_head_dims: {value_head_dims}") @@ -196,35 +195,59 @@ def to_int_list(s): num_layers = num_encoder_layers[i] key_dim = query_head_dims[i] * num_heads[i] embed_dim = encoder_dims[i] - nonlin_attn_head_dim = 3 * embed_dim // 4 value_dim = value_head_dims[i] * num_heads[i] - conv_left_pad = cnn_module_kernels[i] // 2 + conv_left_pad = conv_params[i] - 1 for layer in range(num_layers): + # (left_context_len, batch, key_dim) cached_key = torch.zeros( left_context_len[i], batch_size, key_dim ).numpy() - cached_nonlin_attn = torch.zeros( - 1, batch_size, left_context_len[i], nonlin_attn_head_dim - ).numpy() - cached_val1 = torch.zeros( + # (left_context_len, batch, value_dim) + cached_value = torch.zeros( left_context_len[i], batch_size, value_dim ).numpy() - cached_val2 = torch.zeros( - left_context_len[i], batch_size, value_dim + # (batch, embed_dim, conv_left_pad) + cached_conv = torch.zeros( + batch_size, embed_dim, conv_left_pad + ).numpy() + # cached_norm_stats: (batch,) + cached_norm_stats = torch.zeros(batch_size).numpy() + # cached_norm_len: (batch,) + cached_norm_len = torch.zeros(batch_size).numpy() + # cached_attn_wm_sum: (1, batch, value_dim) + cached_attn_wm_sum = torch.zeros( + 1, batch_size, value_dim + ).numpy() + # cached_attn_wm_num_frames: (batch,) + cached_attn_wm_num_frames = torch.zeros( + batch_size, dtype=torch.int64 + ).numpy() + # cached_conv_wm_sum: (1, batch, embed_dim) + cached_conv_wm_sum = torch.zeros( + 1, batch_size, embed_dim + ).numpy() + # cached_conv_wm_num_frames: (batch,) + cached_conv_wm_num_frames = torch.zeros( + batch_size, dtype=torch.int64 ).numpy() - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + cached_attn_wm_sum, + cached_attn_wm_num_frames, + cached_conv_wm_sum, + cached_conv_wm_num_frames, ] - embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() - self.states.append(embed_states) + + # embed_cache: (batch, channels, left_pad, freq) + embed_cache = torch.zeros(batch_size, 128, 6, 19).numpy() + self.states.append(embed_cache) + # processed_lens: (batch,) processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() self.states.append(processed_lens) @@ -267,45 +290,60 @@ def _build_encoder_input_output( encoder_output = ["encoder_out"] def build_inputs_outputs(tensors, i): - assert len(tensors) == 6, len(tensors) + assert len(tensors) == 9, len(tensors) - # (downsample_left, batch_size, key_dim) + # (left_context_len, batch_size, key_dim) name = f"cached_key_{i}" encoder_input[name] = tensors[0] encoder_output.append(f"new_{name}") - # (1, batch_size, downsample_left, nonlin_attn_head_dim) - name = f"cached_nonlin_attn_{i}" + # (left_context_len, batch_size, value_dim) + name = f"cached_value_{i}" encoder_input[name] = tensors[1] encoder_output.append(f"new_{name}") - # (downsample_left, batch_size, value_dim) - name = f"cached_val1_{i}" + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv_{i}" encoder_input[name] = tensors[2] encoder_output.append(f"new_{name}") - # (downsample_left, batch_size, value_dim) - name = f"cached_val2_{i}" + # (batch_size,) + name = f"cached_norm_stats_{i}" encoder_input[name] = tensors[3] encoder_output.append(f"new_{name}") - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv1_{i}" + # (batch_size,) + name = f"cached_norm_len_{i}" encoder_input[name] = tensors[4] encoder_output.append(f"new_{name}") - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv2_{i}" + # (1, batch_size, value_dim) + name = f"cached_attn_wm_sum_{i}" encoder_input[name] = tensors[5] encoder_output.append(f"new_{name}") - for i in range(len(self.states[:-2]) // 6): - build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + # (batch_size,) + name = f"cached_attn_wm_num_frames_{i}" + encoder_input[name] = tensors[6] + encoder_output.append(f"new_{name}") + + # (1, batch_size, embed_dim) + name = f"cached_conv_wm_sum_{i}" + encoder_input[name] = tensors[7] + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = f"cached_conv_wm_num_frames_{i}" + encoder_input[name] = tensors[8] + encoder_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 9): + build_inputs_outputs(self.states[i * 9 : (i + 1) * 9], i) # (batch_size, channels, left_pad, freq) - name = "embed_states" - embed_states = self.states[-2] - encoder_input[name] = embed_states + name = "embed_cache" + embed_cache = self.states[-2] + encoder_input[name] = embed_cache encoder_output.append(f"new_{name}") # (batch_size,) @@ -397,24 +435,24 @@ def read_sound_files( return ans -def create_streaming_feature_extractor() -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. +def compute_fbank(waveform: torch.Tensor) -> torch.Tensor: + """Compute fbank features for the entire waveform at once. + Args: + waveform: + A 1-D float32 tensor of audio samples. Returns: - Return a CPU streaming feature extractor. + Return a 2-D tensor of shape (num_frames, feature_dim). """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - return OnlineFbank(opts) + feat = torchaudio.compliance.kaldi.fbank( + waveform.unsqueeze(0), + num_mel_bins=80, + sample_frequency=16000, + dither=0, + snip_edges=False, + high_freq=-400, + ) + return feat def greedy_search( @@ -479,17 +517,16 @@ def main(): sample_rate = 16000 - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor() - logging.info(f"Reading sound files: {args.sound_file}") waves = read_sound_files( filenames=[args.sound_file], expected_sample_rate=sample_rate, )[0] - tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) - wave_samples = torch.cat([waves, tail_padding]) + # Compute all fbank features at once + logging.info("Computing fbank features") + features = compute_fbank(waves) + logging.info(f"features shape: {features.shape}") num_processed_frames = 0 segment = model.segment @@ -499,34 +536,21 @@ def main(): hyp = None decoder_out = None - chunk = int(1 * sample_rate) # 1 second - start = 0 - while start < wave_samples.numel(): - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - - online_fbank.accept_waveform( - sampling_rate=sample_rate, - waveform=samples, + num_frames = features.size(0) + + while num_processed_frames + segment <= num_frames: + frames = features[num_processed_frames : num_processed_frames + segment] + num_processed_frames += offset + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, ) - while online_fbank.num_frames_ready - num_processed_frames >= segment: - frames = [] - for i in range(segment): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - num_processed_frames += offset - frames = torch.cat(frames, dim=0) - frames = frames.unsqueeze(0) - encoder_out = model.run_encoder(frames) - hyp, decoder_out = greedy_search( - model, - encoder_out, - context_size, - decoder_out, - hyp, - ) - token_table = k2.SymbolTable.from_file(args.tokens) text = "" From 8f94d85e4cf19fd8ca9d7f73350e8fed8e1fd941 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 1 Jun 2026 13:43:41 +0800 Subject: [PATCH 1186/1191] Use batched_rubik optimizer [muon-core] in zipformer, with interp-cosine LR schedule. --- .../ASR/zapformer/combined_scheduler.py | 32 +++++++---- .../ASR/zipformer/batched_rubik.py | 1 + .../ASR/zipformer/combined_scheduler.py | 1 + egs/librispeech/ASR/zipformer/train.py | 57 +++++++++++-------- 4 files changed, 57 insertions(+), 34 deletions(-) create mode 120000 egs/librispeech/ASR/zipformer/batched_rubik.py create mode 120000 egs/librispeech/ASR/zipformer/combined_scheduler.py diff --git a/egs/librispeech/ASR/zapformer/combined_scheduler.py b/egs/librispeech/ASR/zapformer/combined_scheduler.py index 6a26758897..f3eb6a7332 100644 --- a/egs/librispeech/ASR/zapformer/combined_scheduler.py +++ b/egs/librispeech/ASR/zapformer/combined_scheduler.py @@ -181,27 +181,39 @@ def get_lr(self): class InterpCosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, - min_factor: float = 0.05, - **kwargs): # takes also batches_per_epoch and num_epochs args. + min_factor: float = 0.0, + half_cosine_scale: float = 0.0, + linear_scale: float = 0.0, + **kwargs): """ - This cosine LR scheduler is halfway between the conventional cosine LR scheduler - that takes the cosine from 0 to pi, and one that takes the cosine from 0 to pi/2. - It inherits from CombinedLRScheduler (see its documentation - to understand general aspects of usage). + This cosine LR scheduler encompasses the conventional cosine LR scheduler + that takes the cosine from 0 to pi (shifted to 0..1), the half-cosine LR + scheduler that takes the cosine from 0 to pi, and the linear LR scheduler + that takes the linear function from 1 to 0. """ self.min_factor = min_factor + self.half_cosine_scale = half_cosine_scale + self.linear_scale = linear_scale super().__init__(*args, **kwargs) def get_lr(self): progress = self.get_progress() - factor = math.cos((math.pi / 2) * progress) - # factor**2 would be the conventional cosine LR scheduler with cosine from 0 to pi, we interpolate - # between the two. - factor = 0.5 * (factor + factor ** 2) + half_cos = math.cos((math.pi / 2) * progress) + cos = half_cos ** 2 + linear = 1. - progress + + linear_scale = self.linear_scale + half_cosine_scale = self.half_cosine_scale + cosine_scale = 1. - self.half_cosine_scale - linear_scale + assert cosine_scale >= 0.0 + + factor = linear_scale * linear + half_cosine_scale * half_cos + cosine_scale * cos + # apply min_factor via interpolation factor = self.min_factor + factor * (1. - self.min_factor) return [x * factor for x in self.base_lrs] + class HalfCosineLRScheduler(CombinedLRScheduler): def __init__(self, *args, diff --git a/egs/librispeech/ASR/zipformer/batched_rubik.py b/egs/librispeech/ASR/zipformer/batched_rubik.py new file mode 120000 index 0000000000..5c024cfd72 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/batched_rubik.py @@ -0,0 +1 @@ +../zapformer/batched_rubik.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/combined_scheduler.py b/egs/librispeech/ASR/zipformer/combined_scheduler.py new file mode 120000 index 0000000000..04a0322459 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/combined_scheduler.py @@ -0,0 +1 @@ +../zapformer/combined_scheduler.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 6a6ce447e0..d20d996194 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -77,6 +77,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam +from batched_rubik import BatchedRubik from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -105,7 +106,12 @@ torch_autocast, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + +from combined_scheduler import CombinedLRScheduler +from combined_scheduler import InterpCosineLRScheduler + + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler, CombinedLRScheduler] def get_adjusted_batch_count(params: AttributeDict) -> float: @@ -357,6 +363,16 @@ def get_parser(): """, ) + parser.add_argument( + "--batches-per-epoch", + type=int, + default=2200, + help="Assumed number of batches per epoch for purposes of setting learning rate; only " + "makes a difference during the first batch, after which an observed value is used. This " + "is the num batches where num_copies==1, i.e. on the first epoch" + ) + + parser.add_argument( "--start-batch", type=int, @@ -384,24 +400,9 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." + "--base-lr", type=float, default=0.02, help="The base learning rate." ) - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) parser.add_argument( "--ref-duration", @@ -1120,7 +1121,7 @@ def save_bad_model(suffix: str = ""): # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) + scheduler.set_batch(batch_idx) scaler.step(optimizer) scaler.update() @@ -1342,13 +1343,21 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, + optimizer = BatchedRubik( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), + lr=params.base_lr, + beta1=0.99, ) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1) + + # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. + # this configuration is halfway between a linear function (1 to 0) and the conventional + # cosine LR scheduler. It decays to a minimum of 0.025. + scheduler = InterpCosineLRScheduler(optimizer, + min_factor=0.025, + linear_scale=0.5, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1459,7 +1468,7 @@ def remove_short_and_long_utt(c: Cut): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) + scheduler.set_epoch(epoch) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) From 60bcaa76f0fbf9d819711eba6d5db442fe9df7d0 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 8 Jun 2026 10:15:54 +0800 Subject: [PATCH 1187/1191] Add giga/cv test sets for zipformer --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 112 +++++++++++++++ egs/librispeech/ASR/zipformer/decode.py | 127 +++++++++++++++++- 2 files changed, 235 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index d2f6db8335..a9ce6b8d36 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -18,13 +18,21 @@ import argparse import inspect +import glob import logging +import re + from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import numpy as np # to set its random seed + import torch +import lhotse + from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy + from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, @@ -497,3 +505,107 @@ def gigaspeech_dev_cuts(self) -> CutSet: def gigaspeech_test_cuts(self) -> CutSet: logging.info("About to get Gigaspeech test cuts") return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") + + +class GigaSpeech: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + + - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz + - gigaspeech_cuts_L.jsonl.gz + - gigaspeech_cuts_M.jsonl.gz + - gigaspeech_cuts_S.jsonl.gz + - gigaspeech_cuts_XS.jsonl.gz + - gigaspeech_cuts_DEV.jsonl.gz + - gigaspeech_cuts_TEST.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_XL_cuts_split(self) -> CutSet: + logging.info("About to get train-XL cuts") + + filenames = list( + glob.glob( + f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz" # noqa + ) + ) + + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + + sorted_filenames = [f[1] for f in idx_filenames] + + logging.info(f"Loading {len(sorted_filenames)} splits") + + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + + def train_XL_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_XL.jsonl.gz" + logging.info(f"About to get train-XL cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_L_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_L.jsonl.gz" + logging.info(f"About to get train-L cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_M_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_M.jsonl.gz" + logging.info(f"About to get train-M cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_S_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_S.jsonl.gz" + logging.info(f"About to get train-S cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_XS_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_XS.jsonl.gz" + logging.info(f"About to get train-XS cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def test_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + logging.info(f"About to get TEST cuts from {f}") + return load_manifest_lazy(f) + + def dev_cuts(self) -> CutSet: + f = self.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + logging.info(f"About to get DEV cuts from {f}") + return load_manifest_lazy(f) + + +class CommonVoice: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files:: + + - cv22-en_cuts_train.jsonl.gz + - cv22-en_cuts_dev.jsonl.gz + - cv22-en_cuts_test.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_cuts(self) -> CutSet: + logging.info("CommonVoice: About to get train cuts") + return load_manifest_lazy( + self.manifest_dir / "cv22-en_cuts_train.jsonl.gz" + ) + + def dev_cuts(self) -> CutSet: + logging.info("CommonVoice: About to get dev cuts") + return load_manifest_lazy( + self.manifest_dir / "cv22-en_cuts_dev.jsonl.gz" + ) + + def test_cuts(self) -> CutSet: + logging.info("CommonVoice: About to get test cuts") + return load_manifest_lazy( + self.manifest_dir / "cv22-en_cuts_test.jsonl.gz" + ) diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 6462d22f86..ac6a44ae67 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -98,6 +98,7 @@ import logging import math import os +import re from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -106,7 +107,7 @@ import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import CommonVoice, GigaSpeech, LibriSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -142,6 +143,80 @@ LOG_EPS = math.log(1e-10) +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def giga_asr_text_post_processing(text: str) -> str: # only used for gigaspeech + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +def giga_post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = giga_asr_text_post_processing(" ".join(ref)).split() + new_hyp = giga_asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def cv_post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + def normalize(text): + return re.sub(r'[^\w\s]', '', text).upper() + new_results = [] + for key, ref, hyp in results: + new_ref = normalize(" ".join(ref)).split() + new_hyp = normalize(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + def get_parser(): parser = argparse.ArgumentParser( @@ -378,6 +453,20 @@ def get_parser(): help="""Skip scoring, but still save the ASR output (for eval sets).""", ) + parser.add_argument( + "--giga", + type=str2bool, + default=False, + help="""If True, decode gigaspeech in addition to librispeech test sets.""", + ) + + parser.add_argument( + "--cv", + type=str2bool, + default=False, + help="""If True, decode commonvoice in addition to librispeech test sets.""", + ) + add_model_arguments(parser) return parser @@ -732,6 +821,10 @@ def save_asr_output( recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) store_transcripts(filename=recogs_filename, texts=results) logging.info(f"The transcripts are stored in {recogs_filename}") @@ -759,6 +852,10 @@ def save_wer_results( logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" @@ -1044,12 +1141,34 @@ def main(): test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + if args.giga: + gigaspeech = GigaSpeech(args.manifest_dir) + test_cuts = gigaspeech.test_cuts() + dev_cuts = gigaspeech.dev_cuts() + giga_test_dl = librispeech.test_dataloaders(test_cuts) + giga_dev_dl = librispeech.test_dataloaders(dev_cuts) + test_sets += ["giga-dev", "giga-test"] + test_dl += [giga_dev_dl, giga_test_dl] + + if args.cv: + commonvoice = CommonVoice(args.manifest_dir) + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + cv_test_dl = librispeech.test_dataloaders(test_cuts) + cv_dev_dl = librispeech.test_dataloaders(dev_cuts) + test_sets += ["cv-dev", "cv-test"] + test_dl += [cv_dev_dl, cv_test_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( From be8a101a51fecfb539128836e5d6b40ad96260be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Jun 2026 15:16:08 +0800 Subject: [PATCH 1188/1191] Make code more robust w.r.t. COMPUTE_DTYPE. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index 6b21b51899..cd3cd8be2f 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -32,7 +32,7 @@ # from nanochat.common import COMPUTE_DTYPE # except: # from logging import info as print0 -# #COMPUTE_DTYPE = torch.float32 +#COMPUTE_DTYPE = torch.float32 COMPUTE_DTYPE = torch.bfloat16 @@ -357,7 +357,9 @@ def muon_core_step(group, state, grad): def t(x): return torch.tensor(x, device=grad.device, dtype=COMPUTE_DTYPE) - step = muon_step_fused(grad.to(COMPUTE_DTYPE), momentum_buffer, second_momentum_buffer, + + grad = grad.to(COMPUTE_DTYPE) if grad.dtype != COMPUTE_DTYPE else grad.clone() + step = muon_step_fused(grad, momentum_buffer, second_momentum_buffer, t(beta1), t(lr), t(beta2), t(eps), 5, (-1 if rows > cols else -2)) return step.reshape(orig_shape) From a1de0b289961a4b9c528aed2913050435fe42f65 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 9 Jun 2026 15:20:39 +0800 Subject: [PATCH 1189/1191] Remove comment. --- egs/librispeech/ASR/zapformer/batched_rubik.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/egs/librispeech/ASR/zapformer/batched_rubik.py b/egs/librispeech/ASR/zapformer/batched_rubik.py index cd3cd8be2f..ca0e92159b 100644 --- a/egs/librispeech/ASR/zapformer/batched_rubik.py +++ b/egs/librispeech/ASR/zapformer/batched_rubik.py @@ -27,11 +27,6 @@ from torch import Tensor from torch.optim import Optimizer -# try: -# from nanochat.common import print0 -# from nanochat.common import COMPUTE_DTYPE -# except: -# from logging import info as print0 #COMPUTE_DTYPE = torch.float32 COMPUTE_DTYPE = torch.bfloat16 From bc6955d91bb4391921da10756e267d41377cdcf3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 11 Jun 2026 19:54:28 +0800 Subject: [PATCH 1190/1191] Remove muon.py --- egs/librispeech/ASR/zapformer/muon.py | 284 -------------------------- egs/librispeech/ASR/zipformer/muon.py | 284 -------------------------- 2 files changed, 568 deletions(-) delete mode 100644 egs/librispeech/ASR/zapformer/muon.py delete mode 100644 egs/librispeech/ASR/zipformer/muon.py diff --git a/egs/librispeech/ASR/zapformer/muon.py b/egs/librispeech/ASR/zapformer/muon.py deleted file mode 100644 index df69d1c166..0000000000 --- a/egs/librispeech/ASR/zapformer/muon.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright 2025 Moonshot AI and the LlamaFactory team. -# -# This code is based on the MoonshotAI's Moonlight library. -# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py -# and the Keller Jordan's Muon library. -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# MIT License -# -# Copyright (c) 2025 Moonshot AI -# Copyright (c) 2024 Keller Jordan -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import math -import torch -import logging -import random - - - - -def norm4(X): - XX = X @ X.T - if random.random() < 0.0001: - norm2 = X.norm() - norm4 = XX.norm().sqrt() - logging.info(f"shape={X.shape}, norm2={norm2} vs norm4={norm4}") - return XX.norm().sqrt() - -def get_muon_shape(shape): - shape = list(shape) - def prod(l): - ans = l[0] - for n in l[1:]: - ans = ans * n - return ans - n = len(shape) - diffs = [ ] - for i in range(1, n): - prod1 = prod(shape[:i]) - prod2 = prod(shape[i:]) - diff = abs(prod1 - prod2) - diffs.append(diff) - min_diff = min(diffs) - for i in range(1, n): - if diffs[i-1] == min_diff: - return prod(shape[:i]), prod(shape[i:]) - -def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int, state: dict) -> "torch.Tensor": - """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. - - We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. - For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing - the slope at zero even beyond the point where the iteration no longer converges all the way to - one everywhere on the interval. This iteration therefore does not produce UV^T but rather something - like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - orig_shape = G.shape - G = G.reshape(get_muon_shape(orig_shape)) - assert len(G.shape) == 2 - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(0) > G.size(1): - X = X.T - - if "delta2_buffer0" not in state: - state["delta2_buffer0"] = torch.ones(X.shape[0], device=X.device, dtype=X.dtype) - state["delta2_buffer1"] = torch.ones(X.shape[1], device=X.device, dtype=X.dtype) - delta2_buffer0 = state["delta2_buffer0"] - delta2_buffer1 = state["delta2_buffer1"] - - - eps = 1e-7 - - # we'll scale both before and after the newton-schulz - row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt().unsqueeze(-1) * (delta2_buffer1 + eps).sqrt()) - X = X * row_col_scale - - # Ensure spectral 4-norm is at most 1 - X = X / (norm4(X) + eps) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - # the following scales so if the newton-schulz was exact, the elements of X would have unit RMS. - X = X * (max(X.shape[0], X.shape[1]) ** 0.5) - X2 = X ** 2 - beta = 0.98 - delta2_buffer0.mul_(beta).add_(X2.mean(dim=1), alpha=(1 - beta)) - delta2_buffer1.mul_(beta).add_(X2.mean(dim=0), alpha=(1 - beta)) - - X = X * row_col_scale - - if G.size(0) > G.size(1): - X = X.T - - return X.reshape(orig_shape) - - -class Muon(torch.optim.Optimizer): - """Muon - MomentUm Orthogonalized by Newton-schulz. - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - muon_params: The parameters to be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - wd: weight decay for muon and adamw, this is a squared type of weight decay, requires a large value - which dimensionally is like an inverse of a parameter rms - """ - def __init__( - self, - params, - lr=1e-3, - wd=10.0, # weight decay is a squared type, needs larger wd value, - momentum=0.95, - nesterov=True, - ns_steps=5, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - scale_limits=(0.5, 4.0), - ): - defaults = dict( - lr=lr, - wd=wd, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - scale_limits=scale_limits, - ) - super().__init__(params, defaults) - - def step(self, closure=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - # Muon loop - params = [p for p in group["params"] if p.numel() != max(p.shape, default=1)] - lr = group["lr"] - wd = group["wd"] - momentum = group["momentum"] - min_scale, max_scale = group["scale_limits"] - - # generate weight updates in distributed fashion - for p in params: - # sanity check - g = p.grad - if g is None: - continue - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - state["scale"] = torch.tensor(1.0, device=g.device) # scalar - state["scale_grad_buffer"] = torch.tensor(0.0, device=g.device) # scalar - buf = state["momentum_buffer"] - scale = state["scale"] - scale_grad_buf = state["scale_grad_buffer"] - buf.mul_(momentum).add_(g) - - scale_grad = (g * p.detach()).sum() - scale_grad_buf.mul_(momentum).add_(scale_grad) - - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - eps = 1.0e-08 - - - u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], state=state) - - # multiplying by 0.2 is what's left of adjust_lr_for_muon(), - # we used the factor of (max(p.shape[0], p.shape[1]) ** 0.5) inside - # zeropower_via_newtonschulz5. - adjusted_lr = 0.2 * lr - - old_scale = scale.clone() - - scale.add_(scale_grad_buf.sign(), alpha=-lr) - scale.clamp_(min=min_scale, max=max_scale) - - scale_ratio = scale / old_scale - - # apply changes in scale, together with conventional decay. - p.data.mul_(scale_ratio * (1 - (lr * wd) ** 2)) - - # apply update - p.data.add_(u * scale, alpha=-adjusted_lr) - - # Adam backup - params = [p for p in group["params"] if p.numel() == max(p.shape, default=1)] - - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["wd"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - if "step" not in state: - state["step"] = 0 - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - state["step"] += 1 - step = state["step"] - buf1 = state["moment1"] - buf2 = state["moment2"] - buf1.lerp_(g, 1 - beta1) - buf2.lerp_(g.square(), 1 - beta2) - - g = buf1 / (eps + buf2.sqrt()) - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - (lr * weight_decay) ** 2) - p.data.add_(g, alpha=-lr / scale) - - return loss diff --git a/egs/librispeech/ASR/zipformer/muon.py b/egs/librispeech/ASR/zipformer/muon.py deleted file mode 100644 index df69d1c166..0000000000 --- a/egs/librispeech/ASR/zipformer/muon.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright 2025 Moonshot AI and the LlamaFactory team. -# -# This code is based on the MoonshotAI's Moonlight library. -# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py -# and the Keller Jordan's Muon library. -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# MIT License -# -# Copyright (c) 2025 Moonshot AI -# Copyright (c) 2024 Keller Jordan -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import math -import torch -import logging -import random - - - - -def norm4(X): - XX = X @ X.T - if random.random() < 0.0001: - norm2 = X.norm() - norm4 = XX.norm().sqrt() - logging.info(f"shape={X.shape}, norm2={norm2} vs norm4={norm4}") - return XX.norm().sqrt() - -def get_muon_shape(shape): - shape = list(shape) - def prod(l): - ans = l[0] - for n in l[1:]: - ans = ans * n - return ans - n = len(shape) - diffs = [ ] - for i in range(1, n): - prod1 = prod(shape[:i]) - prod2 = prod(shape[i:]) - diff = abs(prod1 - prod2) - diffs.append(diff) - min_diff = min(diffs) - for i in range(1, n): - if diffs[i-1] == min_diff: - return prod(shape[:i]), prod(shape[i:]) - -def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int, state: dict) -> "torch.Tensor": - """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. - - We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. - For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing - the slope at zero even beyond the point where the iteration no longer converges all the way to - one everywhere on the interval. This iteration therefore does not produce UV^T but rather something - like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - orig_shape = G.shape - G = G.reshape(get_muon_shape(orig_shape)) - assert len(G.shape) == 2 - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(0) > G.size(1): - X = X.T - - if "delta2_buffer0" not in state: - state["delta2_buffer0"] = torch.ones(X.shape[0], device=X.device, dtype=X.dtype) - state["delta2_buffer1"] = torch.ones(X.shape[1], device=X.device, dtype=X.dtype) - delta2_buffer0 = state["delta2_buffer0"] - delta2_buffer1 = state["delta2_buffer1"] - - - eps = 1e-7 - - # we'll scale both before and after the newton-schulz - row_col_scale = 1. / ((delta2_buffer0 + eps).sqrt().unsqueeze(-1) * (delta2_buffer1 + eps).sqrt()) - X = X * row_col_scale - - # Ensure spectral 4-norm is at most 1 - X = X / (norm4(X) + eps) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - # the following scales so if the newton-schulz was exact, the elements of X would have unit RMS. - X = X * (max(X.shape[0], X.shape[1]) ** 0.5) - X2 = X ** 2 - beta = 0.98 - delta2_buffer0.mul_(beta).add_(X2.mean(dim=1), alpha=(1 - beta)) - delta2_buffer1.mul_(beta).add_(X2.mean(dim=0), alpha=(1 - beta)) - - X = X * row_col_scale - - if G.size(0) > G.size(1): - X = X.T - - return X.reshape(orig_shape) - - -class Muon(torch.optim.Optimizer): - """Muon - MomentUm Orthogonalized by Newton-schulz. - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - muon_params: The parameters to be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - wd: weight decay for muon and adamw, this is a squared type of weight decay, requires a large value - which dimensionally is like an inverse of a parameter rms - """ - def __init__( - self, - params, - lr=1e-3, - wd=10.0, # weight decay is a squared type, needs larger wd value, - momentum=0.95, - nesterov=True, - ns_steps=5, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - scale_limits=(0.5, 4.0), - ): - defaults = dict( - lr=lr, - wd=wd, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - scale_limits=scale_limits, - ) - super().__init__(params, defaults) - - def step(self, closure=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - # Muon loop - params = [p for p in group["params"] if p.numel() != max(p.shape, default=1)] - lr = group["lr"] - wd = group["wd"] - momentum = group["momentum"] - min_scale, max_scale = group["scale_limits"] - - # generate weight updates in distributed fashion - for p in params: - # sanity check - g = p.grad - if g is None: - continue - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - state["scale"] = torch.tensor(1.0, device=g.device) # scalar - state["scale_grad_buffer"] = torch.tensor(0.0, device=g.device) # scalar - buf = state["momentum_buffer"] - scale = state["scale"] - scale_grad_buf = state["scale_grad_buffer"] - buf.mul_(momentum).add_(g) - - scale_grad = (g * p.detach()).sum() - scale_grad_buf.mul_(momentum).add_(scale_grad) - - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - eps = 1.0e-08 - - - u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], state=state) - - # multiplying by 0.2 is what's left of adjust_lr_for_muon(), - # we used the factor of (max(p.shape[0], p.shape[1]) ** 0.5) inside - # zeropower_via_newtonschulz5. - adjusted_lr = 0.2 * lr - - old_scale = scale.clone() - - scale.add_(scale_grad_buf.sign(), alpha=-lr) - scale.clamp_(min=min_scale, max=max_scale) - - scale_ratio = scale / old_scale - - # apply changes in scale, together with conventional decay. - p.data.mul_(scale_ratio * (1 - (lr * wd) ** 2)) - - # apply update - p.data.add_(u * scale, alpha=-adjusted_lr) - - # Adam backup - params = [p for p in group["params"] if p.numel() == max(p.shape, default=1)] - - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["wd"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - if "step" not in state: - state["step"] = 0 - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - state["step"] += 1 - step = state["step"] - buf1 = state["moment1"] - buf2 = state["moment2"] - buf1.lerp_(g, 1 - beta1) - buf2.lerp_(g.square(), 1 - beta2) - - g = buf1 / (eps + buf2.sqrt()) - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - (lr * weight_decay) ** 2) - p.data.add_(g, alpha=-lr / scale) - - return loss From 7e077af02a3ad93104c09e3a88eee1135c25b166 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 11 Jun 2026 19:55:21 +0800 Subject: [PATCH 1191/1191] take zipformer/train.py from master, move this train.py to train_newoptim.py --- egs/librispeech/ASR/zipformer/train.py | 57 +- .../ASR/zipformer/train_newoptim.py | 1612 +++++++++++++++++ 2 files changed, 1636 insertions(+), 33 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/train_newoptim.py diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index d20d996194..6a6ce447e0 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -77,7 +77,6 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam -from batched_rubik import BatchedRubik from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor @@ -106,12 +105,7 @@ torch_autocast, ) - -from combined_scheduler import CombinedLRScheduler -from combined_scheduler import InterpCosineLRScheduler - - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler, CombinedLRScheduler] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_adjusted_batch_count(params: AttributeDict) -> float: @@ -363,16 +357,6 @@ def get_parser(): """, ) - parser.add_argument( - "--batches-per-epoch", - type=int, - default=2200, - help="Assumed number of batches per epoch for purposes of setting learning rate; only " - "makes a difference during the first batch, after which an observed value is used. This " - "is the num batches where num_copies==1, i.e. on the first epoch" - ) - - parser.add_argument( "--start-batch", type=int, @@ -400,9 +384,24 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.02, help="The base learning rate." + "--base-lr", type=float, default=0.045, help="The base learning rate." ) + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) parser.add_argument( "--ref-duration", @@ -1121,7 +1120,7 @@ def save_bad_model(suffix: str = ""): # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() - scheduler.set_batch(batch_idx) + scheduler.step_batch(params.batch_idx_train) scaler.step(optimizer) scaler.update() @@ -1343,21 +1342,13 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = BatchedRubik( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), - lr=params.base_lr, - beta1=0.99, + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, ) - - # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. - # this configuration is halfway between a linear function (1 to 0) and the conventional - # cosine LR scheduler. It decays to a minimum of 0.025. - scheduler = InterpCosineLRScheduler(optimizer, - min_factor=0.025, - linear_scale=0.5, - batches_per_epoch=params.batches_per_epoch, - num_epochs=params.num_epochs) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1468,7 +1459,7 @@ def remove_short_and_long_utt(c: Cut): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.set_epoch(epoch) + scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) diff --git a/egs/librispeech/ASR/zipformer/train_newoptim.py b/egs/librispeech/ASR/zipformer/train_newoptim.py new file mode 100755 index 0000000000..d20d996194 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/train_newoptim.py @@ -0,0 +1,1612 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Amir Hussein +# Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset import SpecAugment +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from batched_rubik import BatchedRubik +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + torch_autocast, +) + + +from combined_scheduler import CombinedLRScheduler +from combined_scheduler import InterpCosineLRScheduler + + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler, CombinedLRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + help="""Feedforward dimension used in attention decoder""", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--batches-per-epoch", + type=int, + default=2200, + help="Assumed number of batches per epoch for purposes of setting learning rate; only " + "makes a difference during the first batch, after which an observed value is used. This " + "is the num batches where num_copies==1, i.e. on the first epoch" + ) + + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.02, help="The base learning rate." + ) + + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional["GradScaler"] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + spec_augment: Optional[SpecAugment] = None, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + # linear warmup + cr_loss_scale = min(batch_idx_train / warm_step, 1.0) * params.cr_loss_scale + loss += cr_loss_scale * cr_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: "GradScaler", + spec_augment: Optional[SpecAugment] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.set_batch(batch_idx) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if ( + batch_idx % 25 == 0 + and cur_grad_scale < 2.0 + or batch_idx % 100 == 0 + and cur_grad_scale < 8.0 + or batch_idx % 400 == 0 + and cur_grad_scale < 32.0 + ): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = BatchedRubik( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), + lr=params.base_lr, + beta1=0.99, + ) + + + # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. + # this configuration is halfway between a linear function (1 to 0) and the conventional + # cosine LR scheduler. It decays to a minimum of 0.025. + scheduler = InterpCosineLRScheduler(optimizer, + min_factor=0.025, + linear_scale=0.5, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + # For CTC `(T - 2) < len(tokens)` is needed. otherwise inf. in loss appears. + # For Transducer `T < len(tokens)` was okay. + if (T - 2) < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training (too many supervision tokens). " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + spec_augment=spec_augment, + ) + + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.set_epoch(epoch) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + spec_augment=spec_augment, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main()