Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,20 @@ def main():
type=str,
help="Output directory for debugging",
)
# Adding the next two arguments to improve performance on the GPU
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=None,
help="Target fraction of GPU memory vLLM can use for model + KV cache",
)
parser.add_argument(
"--max-model-len",
type=int,
default=None,
help="Maximum sequence length for sizing KV cache",
)
Comment thread
paularamo marked this conversation as resolved.

args = parser.parse_args()

images: list[str] = args.images or []
Expand Down Expand Up @@ -204,6 +218,8 @@ def main():
revision=args.revision,
limit_mm_per_prompt={"image": len(images), "video": len(videos)},
enforce_eager=True,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=args.max_model_len,
)

# Process inputs
Expand Down Expand Up @@ -239,14 +255,18 @@ def main():
"mm_processor_kwargs": video_kwargs,
}
outputs = llm.generate([llm_inputs], sampling_params=sampling_params)

print(SEPARATOR)
full_texts = []
for output in outputs[0].outputs:
output_text = output.text
full_texts.append(output_text)
print("Assistant:")
print(textwrap.indent(output_text.rstrip(), " "))
print(SEPARATOR)

result, _ = extract_tagged_text(output_text)
result, _ = extract_tagged_text(full_texts[-1])

if args.verbose and result:
pprint_dict(result, "Result")

Expand Down