forward_pass.py
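"""Compare forward passes of a FlaxFormer/T5X checkpoint against its Hugging Face ports.

The script feeds the same all-ones token batch through (1) the FlaxFormer model built from a
gin config and a T5X checkpoint, (2) the Hugging Face PyTorch model (optional), and (3) the
Hugging Face Flax model, then prints output shapes, summed logits, and argmax sums as a quick
consistency check.
"""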
import argparse
import gin
import jax.numpy as jnp
import numpy as np
import torch
import t5x
from transformers import AutoModelForSeq2SeqLM, FlaxAutoModelForSeq2SeqLM


def main(config_file: str, checkpoint_dir: str, hf_model_path: str, run_torch: bool, seq_length: int) -> None:
    # Prepare input
    shape = [2, seq_length]
    encoder_input_tokens = np.ones(shape, dtype=np.int32)
    decoder_input_tokens = np.ones(shape, dtype=np.int32)
    decoder_target_tokens = np.ones(shape, dtype=np.int32)
    ################
    ## FlaxFormer ##
    ################
    # Parse config file
    gin.parse_config_file(config_file)
    gin.finalize()

    # Get model
    model_config_ref = gin.query_parameter("%MODEL")
    model = model_config_ref.scoped_configurable_fn()

    # Load checkpoint
    t5x_checkpoint = t5x.checkpoints.load_t5x_checkpoint(checkpoint_dir)

    # Run forward pass
    print("~~~~~~~~~~ FlaxFormer ~~~~~~~~~~~~")
    try:
        embeddings = t5x_checkpoint["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T
        print("FlaxFormer global relpos:", embeddings.sum(), embeddings.shape)
    except Exception:
        # Checkpoint has no side_relpos_bias parameters; skip the diagnostic print.
        pass
    output = model.module.apply(
        {"params": t5x_checkpoint["target"]},
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_input_tokens,
        decoder_target_tokens=decoder_target_tokens,
        enable_dropout=False,
    )
    # Print output shape
    print(output.shape)
    print("~~~~~~~~~~~~~~~~~~~~~~")
    #################
    ## HuggingFace ##
    #################
    if run_torch:
        pt_model = AutoModelForSeq2SeqLM.from_pretrained(hf_model_path)
        print("~~~~~~~~~ HF PyTorch ~~~~~~~~~~~~~")
        with torch.no_grad():
            # Passing only `labels` lets the HF seq2seq model derive decoder_input_ids
            # by shifting the labels right internally.
            pt_output = pt_model(
                input_ids=torch.from_numpy(encoder_input_tokens).long(),
                labels=torch.from_numpy(decoder_target_tokens).long(),
            ).logits
        print(pt_output.shape)
        print("~~~~~~~~~~~~~~~~~~~~~~")
    flax_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(hf_model_path)
    print("~~~~~~~~~ HF Flax ~~~~~~~~~~~~~")
    try:
        flax_embeddings = flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["TransientGlobalSelfAttention"][
            "global_relative_attention_bias"
        ]["embedding"]
        print("HF Flax global relpos:", flax_embeddings.sum(), flax_embeddings.shape)
    except KeyError:
        # This parameter path is absent for architectures without transient-global attention; skip the print.
        pass
    flax_output = flax_model(input_ids=encoder_input_tokens, decoder_input_ids=decoder_target_tokens).logits
    print(flax_output.shape)
    print("~~~~~~~~~~~~~~~~~~~~~~")
    ### Compare outputs ###
    print("FlaxFormer output:", output.sum())
    if run_torch:
        print("HF PyTorch output:", pt_output.sum())
    print("HF Flax output:", flax_output.sum())

    # Compare argmax
    print("FlaxFormer output:", jnp.argmax(output, axis=-1).sum())
    print("HF Flax output:", jnp.argmax(flax_output, axis=-1).sum())
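    # A minimal extra check (not part of the original comparison): element-wise difference between
    # the FlaxFormer and HF Flax logits, assuming both models use the same vocabulary ordering so
    # the shapes line up. A small max-abs difference suggests the weights were ported correctly.
    if output.shape == flax_output.shape:
        print("max abs diff (FlaxFormer vs HF Flax):", jnp.abs(output - flax_output).max())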


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file")
    parser.add_argument("--checkpoint-dir")
    parser.add_argument("--hf-model-path")
    parser.add_argument("--run-torch", action="store_true")
    parser.add_argument("--seq-length", type=int, default=10)
    args = parser.parse_args()
    main(args.config_file, args.checkpoint_dir, args.hf_model_path, args.run_torch, args.seq_length)
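
# Example invocation (paths below are placeholders, not taken from this repo):
#   python forward_pass.py \
#       --config-file configs/model.gin \
#       --checkpoint-dir /path/to/t5x_checkpoint \
#       --hf-model-path /path/to/hf_model \
#       --run-torch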