|
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import threading
+import time
 import weakref
 from dataclasses import dataclass, field
 from queue import Empty, Queue
@@ -11,6 +12,8 @@
 import torch
 import torch.nn.functional as F
 
+from tensorrt_llm.llmapi import tracing
+
 try:
     import ray
 except ModuleNotFoundError:
@@ -268,6 +271,7 @@ def __init__(self,
         self.avg_decoded_tokens_per_iter: Optional[float] = None
         self._done = False
         self.metrics_dict = {}
+        self.trace_headers: Optional[dict[str, str]] = None
 
         if ray_queue is not None:
            if has_event_loop():
@@ -436,6 +440,7 @@ def _handle_sequence(self,
                 raise ValueError(
                     f"Unknown finish reason: {finish_reasons[src_idx]}")
         self.record_stats(output, req_perf_metrics_dict)
+        self.do_tracing(output, req_perf_metrics_dict)
 
     @print_traceback_on_error
     @nvtx_range_debug("handle_response",
@@ -472,7 +477,7 @@ def _handle_response(self,
             self._outputs[0].disaggregated_params = disaggregated_params
 
             if response.metrics:
-                self.metrics_dict = response.metrics
+                self.metrics_dict.update(response.metrics)
 
             if response.error:
                 if self._background_error_handler is not None and (
@@ -570,7 +575,110 @@ def record_stats(self,
                 stats, len(output.token_ids), self.sampling_params.n > 1)
             if processed_metrics_stat:
                 metrics_stats.update(processed_metrics_stat)
-        self.metrics_dict = metrics_stats
+        self.metrics_dict.update(metrics_stats)
+
+    def do_tracing(
+        self,
+        output: CompletionOutput,
+        req_perf_metrics_dict: Optional[dict[str, float]] = None,
+    ) -> None:
+        """Perform distributed tracing for the generation request.
+
+        Args:
+            output (CompletionOutput): The output of the generation result.
+            req_perf_metrics_dict (Optional[dict[str, float]]): Request performance metrics. Defaults to None.
+        """
+        if not tracing.global_otlp_tracer():
+            return
+
+        metrics_dict = self.metrics_dict
+        if not metrics_dict or not req_perf_metrics_dict:
+            # Insufficient request metrics available; trace generation aborted.
+            tracing.insufficient_request_metrics_warning()
+            return
+
+        trace_context = tracing.extract_trace_context(self.trace_headers)
+        sampling_params = self.sampling_params
+
+        # Since arrival_time and other timing metrics are based on different time origins,
+        # we need to apply corrections to align them with absolute timestamps
+        time_correction = time.time() - time.monotonic()
+        arrival_time = req_perf_metrics_dict.get(
+            RequestEventTiming.ARRIVAL_TIME, 0)
+
+        with tracing.global_otlp_tracer().start_as_current_span(
+                "llm_request",
+                kind=tracing.SpanKind.SERVER,
+                context=trace_context,
+                start_time=int((arrival_time + time_correction) * 1e9),
+        ) as span:
+
+            def safe_set_attr(span, attr, value):
+                if value is not None:
+                    span.set_attribute(attr, value)
+
+            safe_set_attr(span,
+                          tracing.SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
+                          sampling_params.temperature)
+            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_P,
+                          sampling_params.top_p)
+            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_K,
+                          sampling_params.top_k)
+            safe_set_attr(
+                span,
+                tracing.SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
+                sampling_params.max_tokens,
+            )
+            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_N,
+                          sampling_params.n)
+            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_ID,
+                          self.id)
+            if prompt_token_ids := getattr(self, "prompt_token_ids", None):
+                safe_set_attr(span,
+                              tracing.SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
+                              len(prompt_token_ids))
+            safe_set_attr(span,
+                          tracing.SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
+                          output.length)
+            safe_set_attr(
+                span, tracing.SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
+                metrics_dict.get(MetricNames.TTFT, -1))
+            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_LATENCY_E2E,
+                          metrics_dict.get(MetricNames.E2E, -1))
+            safe_set_attr(span,
+                          tracing.SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                          metrics_dict.get(MetricNames.REQUEST_QUEUE_TIME, -1))
+            safe_set_attr(
+                span, tracing.SpanAttributes.GEN_AI_RESPONSE_FINISH_REASONS,
+                json.dumps([output.finish_reason])
+                if output.finish_reason else None)
+            safe_set_attr(
+                span,
+                tracing.SpanAttributes.GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME,
+                req_perf_metrics_dict.get(
+                    RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) -
+                req_perf_metrics_dict.get(
+                    RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0))
+
+            if req_perf_metrics_dict.get(
+                    RequestEventTiming.KV_CACHE_TRANSFER_START,
+                    0) and req_perf_metrics_dict.get(
+                        RequestEventTiming.KV_CACHE_TRANSFER_END, 0):
+                tracing.add_event(
+                    tracing.SpanEvents.KV_CACHE_TRANSFER_START,
+                    timestamp=int((req_perf_metrics_dict.get(
+                        RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0) +
+                                   time_correction) * 1e9))
+                tracing.add_event(
+                    tracing.SpanEvents.KV_CACHE_TRANSFER_END,
+                    attributes={
+                        "kv_cache_size":
+                        req_perf_metrics_dict.get(
+                            RequestEventTiming.KV_CACHE_SIZE, 0)
+                    },
+                    timestamp=int((req_perf_metrics_dict.get(
+                        RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) +
+                                   time_correction) * 1e9))
 
 
 class DetokenizedGenerationResultBase(GenerationResultBase):
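A note on the timestamp handling in the do_tracing hunk above: OTLP span start times and event timestamps are absolute, nanosecond-resolution epoch values, while the request performance metrics appear to be recorded against a monotonic clock, so the patch samples a single wall-clock/monotonic offset (time_correction) and adds it to every timestamp before scaling to nanoseconds. A minimal standalone sketch of that conversion; the helper name monotonic_to_epoch_ns is illustrative and not part of the patch:

import time

def monotonic_to_epoch_ns(monotonic_seconds: float) -> int:
    # Offset between the wall clock and the monotonic clock, sampled once so
    # every timestamp derived from it stays on the same basis.
    time_correction = time.time() - time.monotonic()
    return int((monotonic_seconds + time_correction) * 1e9)

# Pretend a request arrived 250 ms ago according to the monotonic clock.
arrival_monotonic = time.monotonic() - 0.250
print(monotonic_to_epoch_ns(arrival_monotonic))  # epoch nanoseconds, usable as a span start_time

Because the same offset is applied to the span start and to the KV-cache transfer events, any inaccuracy in the sampled offset shifts the whole span uniformly rather than distorting the durations within it.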
@@ -688,6 +796,7 @@ def __init__(
         self.disaggregated_params = disaggregated_params
         # minimal sampling params needed for logprob calculation
         self._logprob_params = logprob_params
+        self.trace_headers = generation_request.trace_headers
 
         # for aborting the request
         self._executor: Optional[weakref.ReferenceType[