Update run.py #261

Open · wants to merge 1 commit into main

67 changes: 37 additions & 30 deletions run.py

@@ -13,59 +13,66 @@
 # limitations under the License.

 import logging

 from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
 from runners import InferenceRunner, ModelRunner, sample_from_model


+# Path to the checkpoint directory
 CKPT_PATH = "./checkpoints/"


 def main():
+    # Initialize model configuration
     grok_1_model = LanguageModelConfig(
-        vocab_size=128 * 1024,
+        vocab_size=128 * 1024,  # 128K vocabulary size
         pad_token=0,
         eos_token=2,
-        sequence_len=8192,
+        sequence_len=8192,  # Sequence length

Reviewer: Why add a comment that repeats the parameter name? Is there genuine concern that "sequence_len" isn't already clear?

         embedding_init_scale=1.0,
         output_multiplier_scale=0.5773502691896257,
         embedding_multiplier_scale=78.38367176906169,
         model=TransformerConfig(
-            emb_size=48 * 128,
+            emb_size=48 * 128,  # Embedding size
             widening_factor=8,
             key_size=128,
-            num_q_heads=48,
-            num_kv_heads=8,
-            num_layers=64,
+            num_q_heads=48,  # Query heads
+            num_kv_heads=8,  # Key/Value heads
+            num_layers=64,  # Number of layers
             attn_output_multiplier=0.08838834764831845,
             shard_activations=True,
-            # MoE.
-            num_experts=8,
-            num_selected_experts=2,
-            # Activation sharding.
+            num_experts=8,  # Mixture of Experts (MoE)
+            num_selected_experts=2,  # Selected experts for MoE
             data_axis="data",
             model_axis="model",
         ),
     )
-    inference_runner = InferenceRunner(
-        pad_sizes=(1024,),
-        runner=ModelRunner(
-            model=grok_1_model,
-            bs_per_device=0.125,
-            checkpoint_path=CKPT_PATH,
-        ),
-        name="local",
-        load=CKPT_PATH,
-        tokenizer_path="./tokenizer.model",
-        local_mesh_config=(1, 8),
-        between_hosts_config=(1, 1),
-    )
-    inference_runner.initialize()
-    gen = inference_runner.run()
-
-    inp = "The answer to life the universe and everything is of course"
-    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
+    try:
+        # Initialize the inference runner with the model and configurations
+        inference_runner = InferenceRunner(
+            pad_sizes=(1024,),
+            runner=ModelRunner(
+                model=grok_1_model,
+                bs_per_device=0.125,  # Batch size per device
+                checkpoint_path=CKPT_PATH,
+            ),
+            name="local",
+            load=CKPT_PATH,
+            tokenizer_path="./tokenizer.model",
+            local_mesh_config=(1, 8),  # Configuration for the local execution mesh
+            between_hosts_config=(1, 1),  # Configuration for between-host execution
+        )
+        inference_runner.initialize()
+    except Exception as e:
+        logging.error(f"Failed to initialize the inference runner: {e}")
+        return

Reviewer: This should raise the exception, since a failure here means the inference_runner was never instantiated. Logging the error and returning is not sufficient handling for this kind of startup failure.
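
A minimal sketch of the log-and-re-raise pattern the reviewer is asking for; the Runner class below is a hypothetical stand-in for InferenceRunner, used only to keep the example self-contained:

import logging


class Runner:
    """Hypothetical stand-in for InferenceRunner, for illustration only."""

    def initialize(self) -> None:
        raise RuntimeError("checkpoint not found")  # simulate a startup failure


def main() -> None:
    try:
        runner = Runner()
        runner.initialize()
    except Exception:
        # logging.exception records the error with its full traceback, and the
        # bare `raise` re-raises it, so main() fails loudly instead of
        # returning None and continuing without a usable runner.
        logging.exception("Failed to initialize the inference runner")
        raise


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()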


+    try:
+        gen = inference_runner.run()
+
+        inp = "The answer to life the universe and everything is of course"
+        output = sample_from_model(gen, inp, max_len=100, temperature=0.01)
+        print(f"Output for prompt: '{inp}':\n{output}")

Reviewer: The original code does not place the output on a new line. Can you be sure this won't change behavior for clients that consume this output?

Author: Have you tested the code?

Reviewer: Irrelevant. You changed behavior without test coverage.
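
For concreteness, the old and new print calls write different text to stdout; a quick illustration with a placeholder completion string, since the real model output isn't available in this thread:

inp = "The answer to life the universe and everything is of course"
output = "42."  # placeholder completion, for illustration only

# Old: prompt and completion on one line, separated by a space.
print(f"Output for prompt: {inp}", output)
# -> Output for prompt: The answer to life the universe and everything is of course 42.

# New: prompt quoted, completion on its own line.
print(f"Output for prompt: '{inp}':\n{output}")
# -> Output for prompt: 'The answer to life the universe and everything is of course':
#    42.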

+    except Exception as e:
+        logging.error(f"Failed during model inference: {e}")

Reviewer: This should also raise.


 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)