Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infinicore.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "infinicore/device_event.hpp"
#include "infinicore/nn.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/tensor.hpp"
10 changes: 10 additions & 0 deletions include/infinicore/context/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ void memcpyD2H(void *dst, const void *src, size_t size);
void memcpyD2D(void *dst, const void *src, size_t size);
void memcpyH2H(void *dst, const void *src, size_t size);

// Event APIs for performance measurement and stream synchronization.
// All functions operate on raw infinirt handles (infinirtEvent_t / infinirtStream_t).
// NOTE(review): handles returned by createEvent* presumably have to be released
// with destroyEvent by the caller — confirm against the implementation.
infinirtEvent_t createEvent();
// NOTE(review): flags is presumably a bitmask of infinirtEventFlags_t values — confirm.
infinirtEvent_t createEventWithFlags(uint32_t flags);
// stream defaults to nullptr; NOTE(review): presumably that selects the current stream — confirm.
void recordEvent(infinirtEvent_t event, infinirtStream_t stream = nullptr);
bool queryEvent(infinirtEvent_t event);
void synchronizeEvent(infinirtEvent_t event);
void destroyEvent(infinirtEvent_t event);
// NOTE(review): result is presumably in milliseconds (cf. the ms_ptr parameter of
// infinirtEventElapsedTime) — confirm.
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);

} // namespace context

} // namespace infinicore
125 changes: 125 additions & 0 deletions include/infinicore/device_event.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#pragma once

#include "device.hpp"
#include "infinirt.h"
#include <memory>
#include <stdexcept>

namespace infinicore {

/**
* @brief A device event for timing operations and synchronization across devices.
*
* Similar to torch.cuda.Event, this class provides functionality to:
* - Record events on specific device streams
* - Synchronize with events
* - Measure elapsed time between events
* - Query event completion status
* - Make streams wait for events
*
* The type is move-only: the copy constructor and copy assignment are deleted,
* while move construction and move assignment are declared noexcept.
*/
class DeviceEvent {
private:
infinirtEvent_t event_; // Underlying event handle
Device device_; // Device where this event was created
bool is_recorded_; // Whether the event has been recorded

public:
/**
* @brief Construct a new DeviceEvent on the current device.
*/
DeviceEvent();

/**
* @brief Construct a new DeviceEvent on the current device with specific flags.
* @param flags Event creation flags (e.g., for timing, blocking sync).
*              NOTE(review): presumably a bitmask of infinirtEventFlags_t
*              values from infinirt.h — confirm.
*/
explicit DeviceEvent(uint32_t flags);

/**
* @brief Construct a new DeviceEvent on a specific device.
* @param device Target device for this event
*/
explicit DeviceEvent(Device device);

/**
* @brief Construct a new DeviceEvent on a specific device with flags.
* @param device Target device for this event
* @param flags Event creation flags
*              NOTE(review): presumably a bitmask of infinirtEventFlags_t
*              values from infinirt.h — confirm.
*/
DeviceEvent(Device device, uint32_t flags);

// Disallow copying
DeviceEvent(const DeviceEvent &) = delete;
DeviceEvent &operator=(const DeviceEvent &) = delete;

/**
* @brief Move constructor.
* @param other Event to move from.
*/
DeviceEvent(DeviceEvent &&other) noexcept;

/**
* @brief Move assignment operator.
* @param other Event to move from.
* @return Reference to this object.
*/
DeviceEvent &operator=(DeviceEvent &&other) noexcept;

/**
* @brief Destroy the DeviceEvent and release underlying resources.
*/
~DeviceEvent();

/**
* @brief Record the event on the current stream of its device.
*/
void record();

/**
* @brief Record the event on a specific stream.
* @param stream Stream to record the event on
*/
void record(infinirtStream_t stream);

/**
* @brief Wait for the event to complete (blocking).
*/
void synchronize();

/**
* @brief Check if the event has been completed.
* @return true if completed, false otherwise
*/
bool query() const;

/**
* @brief Calculate elapsed time between this event and another event (in milliseconds).
* @param other The other event to compare with
* @return Elapsed time in milliseconds
* @throws std::runtime_error if events are on different devices or not recorded
*
* NOTE(review): presumably this event is the start and @p other the end of
* the measured interval — confirm against the implementation.
*/
float elapsed_time(const DeviceEvent &other) const;

/**
* @brief Make a stream wait for this event to complete.
* @param stream Stream to make wait for this event (nullptr for current stream)
*/
void wait(infinirtStream_t stream = nullptr) const;

/**
* @brief Get the device where this event was created.
* @return Device associated with this event (returned by value)
*/
Device device() const { return device_; }

/**
* @brief Get the underlying event handle.
* @return Raw event handle
*
* NOTE(review): the handle presumably stays owned by this object (the
* destructor releases underlying resources), so callers should not destroy
* it themselves — confirm.
*/
infinirtEvent_t get() const { return event_; }

/**
* @brief Check if the event has been recorded.
* @return true if recorded, false otherwise
*/
bool is_recorded() const { return is_recorded_; }
};

} // namespace infinicore
10 changes: 10 additions & 0 deletions include/infinirt.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define __INFINIRT_API_H__

#include "infinicore.h"
#include <stdint.h>

typedef void *infinirtStream_t;
typedef void *infinirtEvent_t;
Expand All @@ -27,11 +28,20 @@ typedef enum {
INFINIRT_EVENT_NOT_READY = 1,
} infinirtEventStatus_t;

// Bit flags that can be passed when creating an event
typedef enum {
    INFINIRT_EVENT_DEFAULT = 0,             // No flag set: default event creation behaviour
    INFINIRT_EVENT_DISABLE_TIMING = 1 << 0, // Event will not record timing data
    INFINIRT_EVENT_BLOCKING_SYNC = 1 << 1,  // Event uses blocking synchronization
} infinirtEventFlags_t;

// Event API. Every function reports its outcome as an infiniStatus_t; results
// are written through the *_ptr out-parameters.
__C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr);
// NOTE(review): flags is presumably a bitwise OR of infinirtEventFlags_t values — confirm.
__C __export infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags);
__C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream);
// Writes the event's status (an infinirtEventStatus_t such as INFINIRT_EVENT_NOT_READY) to *status_ptr.
__C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr);
__C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event);
__C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event);
// NOTE(review): *ms_ptr presumably receives the time between start and end in milliseconds — confirm.
__C __export infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end);

// Memory
typedef enum {
Expand Down
19 changes: 19 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import contextlib

import infinicore.nn as nn

# Import context functions
from infinicore.context import (
get_device,
get_device_count,
get_stream,
set_device,
sync_device,
sync_stream,
)
from infinicore.device import device
from infinicore.device_event import DeviceEvent
from infinicore.dtype import (
bfloat16,
bool,
Expand Down Expand Up @@ -47,8 +58,16 @@
"nn",
# Classes.
"device",
"DeviceEvent",
"dtype",
"Tensor",
# Context functions.
"get_device",
"get_device_count",
"get_stream",
"set_device",
"sync_device",
"sync_stream",
# Data Types.
"bfloat16",
"bool",
Expand Down
50 changes: 50 additions & 0 deletions python/infinicore/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from infinicore.lib import _infinicore


def get_device():
    """Return the device that is currently active.

    Returns:
        The object produced by ``_infinicore.get_device()``.
    """
    current = _infinicore.get_device()
    return current


def get_device_count(device_type):
    """Return how many devices of the given type are available.

    Args:
        device_type: The type of device to count; forwarded unchanged to
            ``_infinicore.get_device_count``.

    Returns:
        The count reported by ``_infinicore.get_device_count``.
    """
    count = _infinicore.get_device_count(device_type)
    return count


def set_device(device):
    """Make ``device`` the current active device.

    Args:
        device: Device wrapper whose ``_underlying`` attribute is handed to
            ``_infinicore.set_device``.
    """
    underlying = device._underlying
    _infinicore.set_device(underlying)


def sync_stream():
    """Synchronize the current stream by delegating to ``_infinicore.sync_stream``.

    Returns nothing.
    """
    _infinicore.sync_stream()
    return None


def sync_device():
    """Synchronize the current device by delegating to ``_infinicore.sync_device``.

    Returns nothing.
    """
    _infinicore.sync_device()
    return None


def get_stream():
    """Return the current stream.

    Returns:
        The object produced by ``_infinicore.get_stream()``.
    """
    current_stream = _infinicore.get_stream()
    return current_stream
95 changes: 95 additions & 0 deletions python/infinicore/device_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import infinicore.device
from infinicore.lib import _infinicore


class DeviceEvent:
    """A device event for timing operations and synchronization across devices.

    Similar to torch.cuda.Event, this class provides functionality to:
    - Record events on specific device streams
    - Synchronize with events
    - Measure elapsed time between events
    - Query event completion status
    - Make streams wait for events

    Args:
        device: Target device for this event. If None, uses current device.
        enable_timing: Whether the event should be created with timing enabled.
            When False, the disable-timing bit is OR-ed into ``flags``.
        flags: Event creation flags (e.g., for timing, blocking sync). Default is 0.
    """

    # Mirrors INFINIRT_EVENT_DISABLE_TIMING (0x1) declared in include/infinirt.h.
    _DISABLE_TIMING_FLAG = 0x1

    def __init__(self, device=None, enable_timing=True, flags=0):
        if not enable_timing:
            # Previously this argument was ignored; honor it by setting the
            # disable-timing bit while keeping any other requested flags.
            flags |= self._DISABLE_TIMING_FLAG

        if device is None:
            # Use current device
            if flags == 0:
                self._underlying = _infinicore.DeviceEvent()
            else:
                self._underlying = _infinicore.DeviceEvent(flags)
        elif flags == 0:
            # Construct with device only
            self._underlying = _infinicore.DeviceEvent(device._underlying)
        else:
            # Construct with both device and flags
            self._underlying = _infinicore.DeviceEvent(device._underlying, flags)

    def record(self, stream=None):
        """Record the event.

        Args:
            stream: Stream to record the event on. If None, uses current stream.
        """
        if stream is None:
            self._underlying.record()
        else:
            self._underlying.record(stream)

    def synchronize(self):
        """Wait for the event to complete (blocking)."""
        self._underlying.synchronize()

    def query(self):
        """Check if the event has been completed.

        Returns:
            bool: True if completed, False otherwise.
        """
        return self._underlying.query()

    def elapsed_time(self, other):
        """Calculate elapsed time between this event and another event.

        Args:
            other: The other DeviceEvent to compare with

        Returns:
            float: Elapsed time in milliseconds between this event and the other event

        Raises:
            RuntimeError: If events are on different devices or not recorded
        """
        return self._underlying.elapsed_time(other._underlying)

    def wait(self, stream=None):
        """Make a stream wait for this event to complete.

        Args:
            stream: Stream to make wait for this event. If None, uses current stream.
        """
        # NOTE(review): None is forwarded as-is; the native binding is presumably
        # expected to map it to the current stream — confirm.
        self._underlying.wait(stream)

    @property
    def device(self):
        """Get the device where this event was created."""
        return infinicore.device._from_infinicore_device(self._underlying.device)

    @property
    def is_recorded(self):
        """Check if the event has been recorded."""
        return self._underlying.is_recorded

    def __repr__(self):
        return f"DeviceEvent(device={self.device}, recorded={self.is_recorded})"
Loading