Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 81 additions & 26 deletions projects/hip-tests/catch/unit/graph/hipGraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ static void init_input(float* a, size_t size) {
/**
* Regular procedure of using stream with async api calls
*/
static void hipWithoutGraphs(float* inputVec_h, float* inputVec_d,
double* outputVec_d, double* result_d, size_t inputSize, size_t numOfBlocks) {
static void hipWithoutGraphs(float* inputVec_h, float* inputVec_d, double* outputVec_d,
double* result_d, size_t inputSize, size_t numOfBlocks,
unsigned int streamFlags) {
hipStream_t stream1, stream2, stream3;
hipEvent_t forkStreamEvent, memsetEvent1, memsetEvent2;
double result_h = 0.0;
HIP_CHECK(hipStreamCreate(&stream1));
HIP_CHECK(hipStreamCreateWithFlags(&stream1, streamFlags));
HIP_CHECK(hipStreamCreate(&stream2));
HIP_CHECK(hipStreamCreate(&stream3));
HIP_CHECK(hipEventCreate(&forkStreamEvent));
Expand Down Expand Up @@ -123,14 +124,14 @@ static void hipWithoutGraphs(float* inputVec_h, float* inputVec_d,
* Capturing sequence of operations in stream and launching graph
* with the nodes automatically added.
*/
static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d,
double* outputVec_d, double* result_d,
size_t inputSize, size_t numOfBlocks) {
static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d, double* outputVec_d,
double* result_d, size_t inputSize, size_t numOfBlocks,
unsigned int captureStreamFlags) {
hipStream_t stream1, stream2, stream3, streamForGraph;
hipEvent_t forkStreamEvent, memsetEvent1, memsetEvent2;
hipGraph_t graph;
double result_h = 0.0;
HIP_CHECK(hipStreamCreate(&stream1));
HIP_CHECK(hipStreamCreateWithFlags(&stream1, captureStreamFlags));
HIP_CHECK(hipStreamCreate(&stream2));
HIP_CHECK(hipStreamCreate(&stream3));
HIP_CHECK(hipStreamCreate(&streamForGraph));
Expand Down Expand Up @@ -159,12 +160,40 @@ static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d,
HIP_CHECK(hipMemcpyAsync(&result_h, result_d, sizeof(double),
hipMemcpyDefault, stream1));
HIP_CHECK(hipStreamEndCapture(stream1, &graph));
hipGraphNode_t* nodes{nullptr};
size_t numNodes = 0;
HIP_CHECK(hipGraphGetNodes(graph, nodes, &numNodes));
HIP_CHECK(hipGraphGetNodes(graph, nullptr, &numNodes));

std::vector<hipGraphNode_t> nodes(numNodes);
HIP_CHECK(hipGraphGetNodes(graph, nodes.data(), &numNodes));
INFO("Num of nodes in the graph created using stream capture API"
<< numNodes);
HIP_CHECK(hipGraphGetRootNodes(graph, nodes, &numNodes));
REQUIRE(numNodes == 6);
for (size_t i = 0; i < nodes.size(); ++i) {
hipGraphNodeType node_type;
HIP_CHECK(hipGraphNodeGetType(nodes[i], &node_type));
switch (i) {
case (0):
REQUIRE(node_type == hipGraphNodeTypeMemcpy);
break;
case (1):
REQUIRE(node_type == hipGraphNodeTypeMemset);
break;
case (2):
REQUIRE(node_type == hipGraphNodeTypeMemset);
break;
case (3):
REQUIRE(node_type == hipGraphNodeTypeKernel);
break;
case (4):
REQUIRE(node_type == hipGraphNodeTypeKernel);
break;
case (5):
REQUIRE(node_type == hipGraphNodeTypeMemcpy);
break;
}
}

HIP_CHECK(hipGraphGetRootNodes(graph, nullptr, &numNodes));
INFO("Num of root nodes in the graph created using"
" stream capture API" << numNodes);
hipGraphExec_t graphExec;
Expand Down Expand Up @@ -203,15 +232,15 @@ static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d,
/**
* Manual procedure of adding nodes to graphs and mapping dependencies.
*/
static void hipGraphsManual(float* inputVec_h, float* inputVec_d,
double* outputVec_d, double* result_d, size_t inputSize,
size_t numOfBlocks) {
static void hipGraphsManual(float* inputVec_h, float* inputVec_d, double* outputVec_d,
double* result_d, size_t inputSize, size_t numOfBlocks,
unsigned int captureStreamFlags) {
hipStream_t streamForGraph;
hipGraph_t graph;
std::vector<hipGraphNode_t> nodeDependencies;
hipGraphNode_t memcpyNode, kernelNode, memsetNode;
double result_h = 0.0;
HIP_CHECK(hipStreamCreate(&streamForGraph));
HIP_CHECK(hipStreamCreateWithFlags(&streamForGraph, captureStreamFlags));
auto start = std::chrono::high_resolution_clock::now();
hipKernelNodeParams kernelNodeParams{};
hipMemsetParams memsetParams{};
Expand Down Expand Up @@ -269,12 +298,39 @@ static void hipGraphsManual(float* inputVec_h, float* inputVec_d,
nodeDependencies.clear();
nodeDependencies.push_back(memcpyNode);
hipGraphExec_t graphExec;
hipGraphNode_t* nodes{nullptr};
size_t numNodes{};
HIP_CHECK(hipGraphGetNodes(graph, nodes, &numNodes));
INFO("Num of nodes in the graph created using hipGraphs Manual"
<< numNodes);
HIP_CHECK(hipGraphGetRootNodes(graph, nodes, &numNodes));
size_t numNodes = 0;
HIP_CHECK(hipGraphGetNodes(graph, nullptr, &numNodes));

std::vector<hipGraphNode_t> nodes(numNodes);
HIP_CHECK(hipGraphGetNodes(graph, nodes.data(), &numNodes));
INFO("Num of nodes in the graph created using stream capture API" << numNodes);
REQUIRE(numNodes == 6);
for (size_t i = 0; i < nodes.size(); ++i) {
hipGraphNodeType node_type;
HIP_CHECK(hipGraphNodeGetType(nodes[i], &node_type));
switch (i) {
case (0):
REQUIRE(node_type == hipGraphNodeTypeMemcpy);
break;
case (1):
REQUIRE(node_type == hipGraphNodeTypeMemset);
break;
case (2):
REQUIRE(node_type == hipGraphNodeTypeMemset);
break;
case (3):
REQUIRE(node_type == hipGraphNodeTypeKernel);
break;
case (4):
REQUIRE(node_type == hipGraphNodeTypeKernel);
break;
case (5):
REQUIRE(node_type == hipGraphNodeTypeMemcpy);
break;
}
}

HIP_CHECK(hipGraphGetRootNodes(graph, nullptr, &numNodes));
INFO("Num of root nodes in the graph created using"
" hipGraphs Manual" << numNodes);
HIP_CHECK(hipGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0));
Expand Down Expand Up @@ -315,6 +371,7 @@ TEST_CASE("Unit_hipGraph_BasicFunctional") {
constexpr size_t maxBlocks = 512;
float *inputVec_d{nullptr}, *inputVec_h{nullptr};
double *outputVec_d{nullptr}, *result_d{nullptr};
unsigned int streamFlags = GENERATE(hipStreamDefault, hipStreamNonBlocking);

INFO("Elements : " << size << " ThreadsPerBlock : " << THREADS_PER_BLOCK);
INFO("Graph Launch iterations = " << GRAPH_LAUNCH_ITERATIONS);
Expand All @@ -328,18 +385,16 @@ TEST_CASE("Unit_hipGraph_BasicFunctional") {
init_input(inputVec_h, size);

SECTION("Execution Without HIPGraphs") {
hipWithoutGraphs(inputVec_h, inputVec_d, outputVec_d,
result_d, size, maxBlocks);
hipWithoutGraphs(inputVec_h, inputVec_d, outputVec_d, result_d, size, maxBlocks, streamFlags);
}

SECTION("Manual HIPGraph") {
hipGraphsManual(inputVec_h, inputVec_d, outputVec_d,
result_d, size, maxBlocks);
hipGraphsManual(inputVec_h, inputVec_d, outputVec_d, result_d, size, maxBlocks, streamFlags);
}

SECTION("HIPGraphs Using StreamCapture") {
hipGraphsUsingStreamCapture(inputVec_h, inputVec_d,
outputVec_d, result_d, size, maxBlocks);
hipGraphsUsingStreamCapture(inputVec_h, inputVec_d, outputVec_d, result_d, size, maxBlocks,
streamFlags);
}

HIP_CHECK(hipFree(inputVec_d));
Expand Down