Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def main():
initalize=False,
)
evaluator.evaluate()

import torch
torch.distributed.destroy_process_group()


if __name__ == "__main__":
Expand Down
22 changes: 17 additions & 5 deletions utilization/model/megatron_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import argparse
import os
import sys
Expand All @@ -5,16 +5,19 @@
import importlib
from logging import getLogger
from typing import Iterator, List, Optional, Tuple, Union
from types import MethodType

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.generation.configuration_utils import GenerationConfig as TransformersGenerationConfig

from ..model_enum import MEGATRON_ARGS
from ..utils import GenerationArg, ModelArguments, resolve_generation_args
from .model import Model
from .model_utils.conversation import Conversation
from .model_utils.keywords_criteria import KeyWordsCriteria
from .model_utils.megatron_generation_utils import megatron_generate, gpt_prepare_inputs_for_generation

logger = getLogger(__name__)

Expand Down Expand Up @@ -98,7 +101,8 @@

# Set up model and load checkpoint
if args.megatron_model_provider is None:
if megatron_args.spec == ["megatron.core.models.mamba.mamba_layer_specs", "mamba_stack_spec"]:
# Match only on the spec module (first element), so any spec taken from the Mamba layer-spec module selects the Mamba provider.
if megatron_args.spec is not None and megatron_args.spec[0] == "megatron.core.models.mamba.mamba_layer_specs":
megatron_model_provider = "pretrain_mamba"
else:
megatron_model_provider = "pretrain_gpt"
Expand All @@ -111,6 +115,10 @@
model = model[0]
self.model = model

# Bind megatron_generate and gpt_prepare_inputs_for_generation to the model instance as its generate / prepare_inputs_for_generation methods.
self.model.generate = MethodType(megatron_generate, self.model)
self.model.prepare_inputs_for_generation = MethodType(gpt_prepare_inputs_for_generation, self.model)

self._tokenizer = get_tokenizer()._tokenizer
self._tokenizer.model_max_length = megatron_args.max_position_embeddings
self.model_max_input_and_output = self.tokenizer.model_max_length
Expand Down Expand Up @@ -358,7 +366,7 @@
[KeyWordsCriteria(self.stop_id_sequences)]
}

self.generation_kwargs = resolve_generation_args(
generation_kwargs = resolve_generation_args(
self.args,
extra_model_args,
MEGATRON_ARGS,
Expand All @@ -368,10 +376,12 @@
"eos_token_id": self.tokenizer.eos_token_id,
},
)
self.stopping_criteria = generation_kwargs.pop('stopping_criteria', [])
self.generation_config = TransformersGenerationConfig(**generation_kwargs)

if len(extra_model_args) > 0:
logger.warning(f"Unused generation arguments: {extra_model_args}")
return self.generation_kwargs
return generation_kwargs

def generation(self,
batched_inputs: List[Conversation],
Expand Down Expand Up @@ -439,15 +449,17 @@
batched_encodings = self.tokenizer(
batched_inputs,
padding=True,
padding_side="left",
truncation=True,
return_attention_mask=True,
return_tensors="pt",
return_token_type_ids=False,
).to(self.device)

batch_outputs = self.model.generate(**batched_encodings,
**self.generation_kwargs)
for criteria in self.generation_kwargs.get("stopping_criteria", []):
stopping_criteria=self.stopping_criteria,
generation_config=self.generation_config)
for criteria in self.stopping_criteria:
if isinstance(criteria, KeyWordsCriteria):
criteria.step()

Expand Down
Loading
Loading