diff --git a/scgpt/model/generation_model.py b/scgpt/model/generation_model.py
index 62cef1fa..59b7514d 100644
--- a/scgpt/model/generation_model.py
+++ b/scgpt/model/generation_model.py
@@ -74,6 +74,7 @@ def __init__(
                 use_fast_transformer = False
         self.use_fast_transformer = use_fast_transformer
 
+        # STEP 1: EMBED the input vectors: gene tokens, binned expressions and the condition vector (the perturbation flags here)
         self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token])
         self.value_encoder = ContinuousValueEncoder(d_model, dropout)
         self.pert_encoder = nn.Embedding(3, d_model, padding_idx=pert_pad_id)
@@ -81,6 +82,8 @@ def __init__(
         print("Using simple batchnorm instead of domain specific batchnorm")
         self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5)
 
+        # STEP 2: create the encoder with the number of blocks and sizes defined by the parameters (d_model=512, nhead=8, etc.)
+        # use a built-in standard attention encoder from either fast_transformer, flash-attention or torch
         if use_fast_transformer:
             if fast_transformer_backend == "linear":
                 self.transformer_encoder = FastTransformerEncoderWrapper(
@@ -102,11 +105,14 @@ def __init__(
             )
             self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
 
+        # STEP 3: create several DECODERs
         # self.decoder = nn.Linear(d_model, 1)
         self.decoder = ExprDecoder(
             d_model,
             explicit_zero_prob=explicit_zero_prob,
         )
+
+        # decoder for the classification task
         self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
         if do_mvc:
             self.mvc_decoder = MVCDecoder(
@@ -131,13 +137,21 @@ def _encode(
         input_pert_flags,
         src_key_padding_mask: Tensor,
     ) -> Tensor:
+
+        # EMBED all input vectors
         src = self.encoder(src)  # (batch, seq_len, embsize)
         self.cur_gene_token_embs = src
         values = self.value_encoder(values)  # (batch, seq_len, embsize)
         perts = self.pert_encoder(input_pert_flags)  # (batch, seq_len, embsize)
+
+        # SUM UP: collapse all of them into a single embedding per position (see the paper)
        total_embs = src + values + perts
 
+        # hotfix to be able to install a later version of flash-attention
+        # for CUDA 12: https://github.com/bowang-lab/scGPT/issues/69
         total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
+
+        # ENCODE with all encoder blocks (12 by default)
         output = self.transformer_encoder(
             total_embs, src_key_padding_mask=src_key_padding_mask
         )
@@ -203,10 +217,16 @@ def forward(
             do_sample = True
             logger.warning("Auto set do_sample to True when model is in eval mode.")
 
+        # STEP 1: ENCODE -> embedding plus all encoder blocks (12 blocks by default)
         transformer_output = self._encode(
             src, values, input_pert_flags, src_key_padding_mask
         )
+
+        # STEP 2: DECODE -> ExprDecoder: Linear/ReLU/Linear/ReLU/Linear
+        # use different decoders depending on the task: masked language modelling, classification, etc.
         output = {}
+
+        # Masked Language Modeling (MLM)
         mlm_output = self.decoder(transformer_output)
         if self.explicit_zero_prob and do_sample:
             bernoulli = Bernoulli(probs=mlm_output["zero_probs"])
@@ -217,8 +237,12 @@ def forward(
             output["mlm_zero_probs"] = mlm_output["zero_probs"]
 
         cell_emb = self._get_cell_emb_from_layer(transformer_output, values)
+
+        # if cell type classification (CLS) objective
         if CLS:
             output["cls_output"] = self.cls_decoder(cell_emb)  # (batch, n_cls)
+
+        # if masked value prediction for cell embedding (MVC) objective
         if MVC:
             mvc_output = self.mvc_decoder(
                 cell_emb,
@@ -231,6 +255,8 @@ def forward(
             output["mvc_output"] = mvc_output["pred"]  # (batch, seq_len)
             if self.explicit_zero_prob:
                 output["mvc_zero_probs"] = mvc_output["zero_probs"]
+
+        # if elastic cell similarity (ECS) objective
         if ECS:
             # Here using customized cosine similarity instead of F.cosine_similarity
             # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064
@@ -246,6 +272,13 @@ def forward(
             output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2)
 
+        # output might contain:
+        #   output["mlm_output"]
+        #   output["mlm_zero_probs"]
+        #   output["cls_output"]
+        #   output["mvc_output"]
+        #   output["mvc_zero_probs"]
+        #   output["loss_ecs"]
         return output
 
     def encode_batch(
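Taken together, the comments above describe a three-stage pipeline: embed gene tokens, binned expression values and perturbation flags; sum the three embeddings and normalize them; run the result through the transformer encoder; and feed the encoder output to task-specific decoders. The sketch below is a minimal, self-contained approximation of that flow for illustration only; the class name `TinyPerturbationModel`, the stand-in layers and all sizes are assumptions and do not reproduce scGPT's actual `GeneEncoder`, `ContinuousValueEncoder` or `ExprDecoder` implementations.

```python
# Minimal sketch of the embed -> sum -> encode -> decode flow annotated in the
# diff above. All names, layers and sizes here are illustrative assumptions,
# not the real scGPT classes.
import torch
import torch.nn as nn


class TinyPerturbationModel(nn.Module):
    def __init__(self, ntoken=1000, d_model=512, nhead=8, nlayers=12, pad_id=0):
        super().__init__()
        # STEP 1: embeddings for gene tokens, binned expression values and
        # perturbation flags (3 possible flag values, as in the diff).
        self.gene_encoder = nn.Embedding(ntoken, d_model, padding_idx=pad_id)
        self.value_encoder = nn.Linear(1, d_model)  # stand-in for ContinuousValueEncoder
        self.pert_encoder = nn.Embedding(3, d_model, padding_idx=2)
        self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5)

        # STEP 2: standard torch transformer encoder (batch_first, as in the diff).
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(layer, nlayers)

        # STEP 3: a single expression decoder (Linear/ReLU/Linear), standing in
        # for scGPT's ExprDecoder / ClsDecoder / MVCDecoder family.
        self.decoder = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 1)
        )

    def forward(self, genes, values, pert_flags, src_key_padding_mask):
        # Embed each input, then collapse them into one vector per gene position.
        total_embs = (
            self.gene_encoder(genes)
            + self.value_encoder(values.unsqueeze(-1))
            + self.pert_encoder(pert_flags)
        )
        # BatchNorm1d normalizes over dim 1, hence the permutes around it.
        total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
        hidden = self.transformer_encoder(
            total_embs, src_key_padding_mask=src_key_padding_mask
        )
        return self.decoder(hidden).squeeze(-1)  # predicted expression per gene


# Tiny smoke test: 2 cells x 16 genes.
model = TinyPerturbationModel(nlayers=2)
genes = torch.randint(0, 1000, (2, 16))
values = torch.rand(2, 16)
pert_flags = torch.randint(0, 2, (2, 16))
padding_mask = torch.zeros(2, 16, dtype=torch.bool)
print(model(genes, values, pert_flags, padding_mask).shape)  # torch.Size([2, 16])
```

Note that the real `forward` returns a dictionary with one entry per active objective (MLM, CLS, MVC, ECS), as listed in the final added comment block, whereas this sketch returns only a single expression prediction per gene.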