Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the Order of GatedAct to be (act, linear) when Pax2TE. #892

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,38 @@ def _generate_ckpt_map(self):
hidden_dim = num_of_head * head_dim
mlp_intermediate_dim = self.model_config.mlp_intermediate_dim

for i in range(self.model_config.num_of_layer):
ckpt_map.update({
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w":
self._get_convert_pkg(
if self.use_gated_act:
ckpt_map[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"] = \
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
(hidden_dim, mlp_intermediate_dim), 0,
extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"],
stack_dim = -2) if self.use_gated_act else \
extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"],
stack_dim = -2)
else:
ckpt_map[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"] = \
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
(hidden_dim, mlp_intermediate_dim), 0,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),

for i in range(self.model_config.num_of_layer):
ckpt_map_for_ffn1 = {}
if self.use_gated_act:
ckpt_map_for_ffn1[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"] = \
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
(hidden_dim, mlp_intermediate_dim), 0,
extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"],
stack_dim = -2)
else:
ckpt_map_for_ffn1[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"] = \
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
(hidden_dim, mlp_intermediate_dim), 0,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),

ckpt_map.update({
**ckpt_map_for_ffn1,
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.linear.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_kernel",
Expand Down Expand Up @@ -313,17 +333,28 @@ def _generate_ckpt_map(self):
hidden_dim = num_of_head * head_dim
mlp_intermediate_dim = self.model_config.mlp_intermediate_dim

ckpt_map_for_ffn1 = {}
if self.use_gated_act:
ckpt_map_for_ffn1['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1_gate.linear.w'] = \
self._get_convert_pkg(
f'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
(hidden_dim, mlp_intermediate_dim), 0,
extra_src_paths = ['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w'],
stack_dim = -2)
else:
ckpt_map_for_ffn1['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w'] = \
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
(num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])))

ckpt_map.update({
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_bias',
(num_of_layer, mlp_intermediate_dim), None,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
(num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
**ckpt_map_for_ffn1,
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_bias',
Expand Down
Loading