Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions python/cuda_cccl/cuda/compute/_nvtx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
NVTX annotation utilities for cuda.compute module.
Uses NVIDIA green (76B900) color and cuda.compute domain.
"""

import functools

import nvtx

# NVIDIA green color hex value (76B900)
NVIDIA_GREEN = 0x76B900

# Domain name for cuda.compute annotations
COMPUTE_DOMAIN = "cuda.compute"


def annotate(message=None, domain=None, category=None, color=None):
"""
Decorator to annotate functions with NVTX markers.

Args:
message: Optional message to display. If None, uses the function name.
domain: Optional NVTX domain string. Defaults to "cuda.compute".
category: Optional category for the annotation.
color: Optional color in hexadecimal format (0xRRGGBB). Defaults to NVIDIA green (0x76B900).

Returns:
Decorated function with NVTX annotations.
"""

def decorator(func):
# Use function name if no message is provided
annotation_message = message if message is not None else func.__name__
annotation_domain = domain if domain is not None else COMPUTE_DOMAIN
annotation_color = color if color is not None else NVIDIA_GREEN

@functools.wraps(func)
def wrapper(*args, **kwargs):
with nvtx.annotate(
annotation_message,
domain=annotation_domain,
color=annotation_color,
category=category,
):
return func(*args, **kwargs)

return wrapper

return decorator
4 changes: 4 additions & 0 deletions python/cuda_cccl/cuda/compute/algorithms/_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .. import _cccl_interop as cccl
from .._caching import cache_with_key
from .._cccl_interop import call_build, set_cccl_iterator_state, to_cccl_value_state
from .._nvtx import annotate
from .._utils.protocols import get_data_pointer, get_dtype, validate_and_get_stream
from .._utils.temp_storage_buffer import TempStorageBuffer
from ..iterators._iterators import IteratorBase
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
is_evenly_segmented,
)

@annotate(message="_Histogram.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -134,6 +136,7 @@ def __call__(
return temp_storage_bytes


@annotate()
@cache_with_key(make_cache_key)
def make_histogram_even(
d_samples: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -173,6 +176,7 @@ def make_histogram_even(
)


@annotate()
def histogram_even(
d_samples: DeviceArrayLike | IteratorBase,
d_histogram: DeviceArrayLike,
Expand Down
4 changes: 4 additions & 0 deletions python/cuda_cccl/cuda/compute/algorithms/_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
set_cccl_iterator_state,
to_cccl_value_state,
)
from .._nvtx import annotate
from .._utils import protocols
from .._utils.protocols import get_data_pointer, validate_and_get_stream
from .._utils.temp_storage_buffer import TempStorageBuffer
Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
self.h_init_cccl,
)

@annotate(message="_Reduce.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -119,6 +121,7 @@ def make_cache_key(

# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
@annotate()
@cache_with_key(make_cache_key)
def make_reduce_into(
d_in: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -148,6 +151,7 @@ def make_reduce_into(
return _Reduce(d_in, d_out, op, h_init)


@annotate()
def reduce_into(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
Expand Down
6 changes: 6 additions & 0 deletions python/cuda_cccl/cuda/compute/algorithms/_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
set_cccl_iterator_state,
to_cccl_value_state,
)
from .._nvtx import annotate
from .._utils import protocols
from .._utils.protocols import get_data_pointer, validate_and_get_stream
from .._utils.temp_storage_buffer import TempStorageBuffer
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
case (False, _bindings.InitKind.NO_INIT):
raise ValueError("Exclusive scan with No init value is not supported")

@annotate(message="_Scan.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -201,6 +203,7 @@ def make_cache_key(

# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
@annotate()
@cache_with_key(make_cache_key)
def make_exclusive_scan(
d_in: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -230,6 +233,7 @@ def make_exclusive_scan(
return _Scan(d_in, d_out, op, init_value, False)


@annotate()
def exclusive_scan(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -267,6 +271,7 @@ def exclusive_scan(

# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
@annotate()
@cache_with_key(make_cache_key)
def make_inclusive_scan(
d_in: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -296,6 +301,7 @@ def make_inclusive_scan(
return _Scan(d_in, d_out, op, init_value, True)


@annotate()
def inclusive_scan(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
set_cccl_iterator_state,
to_cccl_value_state,
)
from .._nvtx import annotate
from .._utils import protocols
from .._utils.protocols import (
get_data_pointer,
Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(
self.h_init_cccl,
)

@annotate(message="_SegmentedReduce.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -166,6 +168,7 @@ def make_cache_key(
)


@annotate()
@cache_with_key(make_cache_key)
def make_segmented_reduce(
d_in: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -199,6 +202,7 @@ def make_segmented_reduce(
return _SegmentedReduce(d_in, d_out, start_offsets_in, end_offsets_in, op, h_init)


@annotate()
def segmented_reduce(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
Expand Down
4 changes: 4 additions & 0 deletions python/cuda_cccl/cuda/compute/algorithms/_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable

from .._caching import CachableFunction, cache_with_key
from .._nvtx import annotate
from .._utils import protocols
from ..iterators._factories import DiscardIterator
from ..iterators._iterators import IteratorBase
Expand Down Expand Up @@ -60,6 +61,7 @@ def _cccl_always_false(x):
_cccl_always_false, # select_second_part_op - always false
)

@annotate(message="_Select.__call__")
def __call__(
self,
temp_storage,
Expand All @@ -81,6 +83,7 @@ def __call__(
)


@annotate()
@cache_with_key(make_cache_key)
def make_select(
d_in: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -115,6 +118,7 @@ def make_select(
return _Select(d_in, d_out, d_num_selected_out, cond)


@annotate()
def select(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ... import _cccl_interop as cccl
from ..._caching import CachableFunction, cache_with_key
from ..._cccl_interop import call_build, set_cccl_iterator_state
from ..._nvtx import annotate
from ..._utils import protocols
from ..._utils.protocols import (
get_data_pointer,
Expand Down Expand Up @@ -107,6 +108,7 @@ def __init__(
self.op_wrapper,
)

@annotate(message="_MergeSort.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -153,6 +155,7 @@ def __call__(
return temp_storage_bytes


@annotate()
@cache_with_key(make_cache_key)
def make_merge_sort(
d_in_keys: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -184,6 +187,7 @@ def make_merge_sort(
return _MergeSort(d_in_keys, d_in_items, d_out_keys, d_out_items, op)


@annotate()
def merge_sort(
d_in_keys: DeviceArrayLike | IteratorBase,
d_in_items: DeviceArrayLike | IteratorBase | None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ... import _cccl_interop as cccl
from ..._caching import cache_with_key
from ..._cccl_interop import call_build, set_cccl_iterator_state
from ..._nvtx import annotate
from ..._utils.protocols import (
get_data_pointer,
get_dtype,
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
decomposer_return_type,
)

@annotate(message="_RadixSort.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -164,6 +166,7 @@ def __call__(
return temp_storage_bytes


@annotate()
@cache_with_key(make_cache_key)
def make_radix_sort(
d_in_keys: DeviceArrayLike | DoubleBuffer,
Expand Down Expand Up @@ -195,6 +198,7 @@ def make_radix_sort(
return _RadixSort(d_in_keys, d_out_keys, d_in_values, d_out_values, order)


@annotate()
def radix_sort(
d_in_keys: DeviceArrayLike | DoubleBuffer,
d_out_keys: DeviceArrayLike | None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ... import _cccl_interop as cccl
from ..._caching import cache_with_key
from ..._cccl_interop import call_build, set_cccl_iterator_state
from ..._nvtx import annotate
from ..._utils.protocols import (
get_data_pointer,
get_dtype,
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
self.end_offsets_in_cccl,
)

@annotate(message="_SegmentedSort.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -166,6 +168,7 @@ def make_cache_key(
)


@annotate()
@cache_with_key(make_cache_key)
def make_segmented_sort(
d_in_keys: DeviceArrayLike | DoubleBuffer,
Expand Down Expand Up @@ -209,6 +212,7 @@ def make_segmented_sort(
)


@annotate()
def segmented_sort(
d_in_keys: DeviceArrayLike | DoubleBuffer,
d_out_keys: DeviceArrayLike | None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .. import _cccl_interop as cccl
from .._caching import CachableFunction, cache_with_key
from .._cccl_interop import call_build, set_cccl_iterator_state
from .._nvtx import annotate
from .._utils import protocols
from .._utils.temp_storage_buffer import TempStorageBuffer
from ..iterators._iterators import IteratorBase
Expand Down Expand Up @@ -108,6 +109,7 @@ def __init__(
self.select_second_part_op_wrapper,
)

@annotate(message="_ThreeWayPartition.__call__")
def __call__(
self,
temp_storage,
Expand Down Expand Up @@ -149,6 +151,7 @@ def __call__(
return temp_storage_bytes


@annotate()
@cache_with_key(make_cache_key)
def make_three_way_partition(
d_in: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -192,6 +195,7 @@ def make_three_way_partition(
)


@annotate()
def three_way_partition(
d_in: DeviceArrayLike | IteratorBase,
d_first_part_out: DeviceArrayLike | IteratorBase,
Expand Down
Loading
Loading