
E2E training procedure #67

@elephantmipt


Hi,

Thank you for sharing the code from your insightful paper!

I'm attempting to train the model in the end-to-end (e2e) setup, and I've run into a question about the embeddings. As I understand it, you use the TextDataset_NoCache class for the dataset, which holds the model's embedding layer in model_emb:

class TextDataset_NoCache(Dataset):
    def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
                 classes=None, shard=0, num_shards=1, eigen_transform=None,
                 mapping_func=None, model_emb=None):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets['train'])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb

In the training script, you pass model=None (via model22) to the load_data_text function in the e2e branch:

if args.experiment == 'random1':
    args.experiment = 'random'
    print('loading from the vocabs here.')
    assert args.in_channel == 64
    assert args.modality == 'roc'
    model22 = torch.nn.Embedding(args.vocab_size, args.in_channel)
    model22_weight = torch.load('predictability/diffusion_models_v7/diff_roc-aug_pad_rand64_'
                                'transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e/'
                                'ema_0.9999_200000.pt', map_location='cpu')['word_embedding.weight']
    model22.weight = model22_weight
    model22.weight.requires_grad = False
else:
    model22 = None

data = load_data_text(
    data_dir=args.data_dir,
    batch_size=args.batch_size,
    image_size=args.image_size,
    class_cond=args.class_cond,
    data_args=args,
    task_mode=args.modality,
    padding_mode=args.padding_mode,  # block, pad
    load_vocab=rev_tokenizer,
    model=model22,
)

I assume that the embeddings are initialized at:

https://github.com/XiangLi1999/Diffusion-LM/blob/main/improved-diffusion/improved_diffusion/text_datasets.py
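
Concretely, my guess is that when model is None, that file builds a fresh, randomly initialized embedding, along these lines (a simplified sketch under that assumption; the function name is mine):

import torch

def build_embedding(model, vocab_size, in_channel):
    # Assumption: in the e2e branch (model=None above), the data pipeline
    # creates a new embedding table with random weights rather than
    # loading one from a checkpoint.
    if model is None:
        model = torch.nn.Embedding(vocab_size, in_channel)
        torch.nn.init.normal_(model.weight)
    return model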

However, in the e2e setup one would presumably want the dataset to use the continuously updated embeddings from the model. Looking through the training loop, I couldn't find any place where the dataset's embeddings are refreshed from the model after a gradient step. Could you please clarify how the embeddings can be trained end-to-end when they are housed inside the dataset class?
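
One mechanism that would make this work without any explicit copying is Python reference sharing: if the dataset and the trained model hold the same nn.Embedding object, every optimizer step is immediately visible through the dataset's reference. A toy illustration of the idea (my own example, not code from this repo):

import torch

emb = torch.nn.Embedding(10, 4)   # one shared embedding table
dataset_ref = emb                 # a "dataset" storing a reference, not a copy

opt = torch.optim.SGD(emb.parameters(), lr=0.1)
loss = emb(torch.tensor([1, 2])).sum()
loss.backward()
opt.step()

# The dataset-side reference reflects the update, because it is the very
# object the optimizer just modified.
assert torch.equal(dataset_ref.weight, emb.weight)

Is this kind of sharing what happens here, or does the dataset hold a detached copy?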

Thank you for your time and for any clarification you can offer!
