Expose vLLM memory-tuning flags in scripts/inference.py.

Adds --gpu-memory-utilization and --max-model-len and forwards both to the
LLM(...) constructor. --gpu-memory-utilization defaults to 0.9 (vLLM's own
default) rather than None, because the value is forwarded unconditionally
and the engine expects a float fraction. --max-model-len stays None so vLLM
derives it from the model config.

NOTE(review): indentation and blank context lines were reconstructed from a
whitespace-collapsed copy of this patch; verify the context against the
target file before applying. The post-image hash on the index line predates
the default change and is informational only.

diff --git a/scripts/inference.py b/scripts/inference.py
index 27d31a3..d5a81b8 100755
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -148,6 +148,20 @@ def main():
         type=str,
         help="Output directory for debugging",
     )
+    # GPU tuning knobs; 0.9 mirrors vLLM's own default so LLM() never receives None.
+    parser.add_argument(
+        "--gpu-memory-utilization",
+        type=float,
+        default=0.9,
+        help="Target fraction of GPU memory vLLM can use for model + KV cache",
+    )
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        default=None,
+        help="Maximum sequence length for sizing KV cache",
+    )
+
     args = parser.parse_args()
 
     images: list[str] = args.images or []
@@ -204,6 +218,8 @@ def main():
         revision=args.revision,
         limit_mm_per_prompt={"image": len(images), "video": len(videos)},
         enforce_eager=True,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        max_model_len=args.max_model_len,
     )
 
     # Process inputs
@@ -239,14 +255,18 @@ def main():
         "mm_processor_kwargs": video_kwargs,
     }
     outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
+
     print(SEPARATOR)
+    full_texts = []
     for output in outputs[0].outputs:
         output_text = output.text
+        full_texts.append(output_text)
         print("Assistant:")
         print(textwrap.indent(output_text.rstrip(), " "))
     print(SEPARATOR)
 
-    result, _ = extract_tagged_text(output_text)
+    result, _ = extract_tagged_text(full_texts[-1])
+
     if args.verbose and result:
         pprint_dict(result, "Result")
 