diff --git a/projects/hip-tests/catch/unit/graph/hipGraph.cc b/projects/hip-tests/catch/unit/graph/hipGraph.cc index e1c12895846..c6e5c47eadd 100644 --- a/projects/hip-tests/catch/unit/graph/hipGraph.cc +++ b/projects/hip-tests/catch/unit/graph/hipGraph.cc @@ -68,12 +68,13 @@ static void init_input(float* a, size_t size) { /** * Regular procedure of using stream with async api calls */ -static void hipWithoutGraphs(float* inputVec_h, float* inputVec_d, - double* outputVec_d, double* result_d, size_t inputSize, size_t numOfBlocks) { +static void hipWithoutGraphs(float* inputVec_h, float* inputVec_d, double* outputVec_d, + double* result_d, size_t inputSize, size_t numOfBlocks, + unsigned int streamFlags) { hipStream_t stream1, stream2, stream3; hipEvent_t forkStreamEvent, memsetEvent1, memsetEvent2; double result_h = 0.0; - HIP_CHECK(hipStreamCreate(&stream1)); + HIP_CHECK(hipStreamCreateWithFlags(&stream1, streamFlags)); HIP_CHECK(hipStreamCreate(&stream2)); HIP_CHECK(hipStreamCreate(&stream3)); HIP_CHECK(hipEventCreate(&forkStreamEvent)); @@ -123,14 +124,14 @@ static void hipWithoutGraphs(float* inputVec_h, float* inputVec_d, * Capturing sequence of operations in stream and launching graph * with the nodes automatically added. */ -static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d, - double* outputVec_d, double* result_d, - size_t inputSize, size_t numOfBlocks) { +static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d, double* outputVec_d, + double* result_d, size_t inputSize, size_t numOfBlocks, + unsigned int captureStreamFlags) { hipStream_t stream1, stream2, stream3, streamForGraph; hipEvent_t forkStreamEvent, memsetEvent1, memsetEvent2; hipGraph_t graph; double result_h = 0.0; - HIP_CHECK(hipStreamCreate(&stream1)); + HIP_CHECK(hipStreamCreateWithFlags(&stream1, captureStreamFlags)); HIP_CHECK(hipStreamCreate(&stream2)); HIP_CHECK(hipStreamCreate(&stream3)); HIP_CHECK(hipStreamCreate(&streamForGraph)); @@ -159,12 +160,40 @@ static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d, HIP_CHECK(hipMemcpyAsync(&result_h, result_d, sizeof(double), hipMemcpyDefault, stream1)); HIP_CHECK(hipStreamEndCapture(stream1, &graph)); - hipGraphNode_t* nodes{nullptr}; size_t numNodes = 0; - HIP_CHECK(hipGraphGetNodes(graph, nodes, &numNodes)); + HIP_CHECK(hipGraphGetNodes(graph, nullptr, &numNodes)); + + std::vector nodes(numNodes); + HIP_CHECK(hipGraphGetNodes(graph, nodes.data(), &numNodes)); INFO("Num of nodes in the graph created using stream capture API" << numNodes); - HIP_CHECK(hipGraphGetRootNodes(graph, nodes, &numNodes)); + REQUIRE(numNodes == 6); + for (size_t i = 0; i < nodes.size(); ++i) { + hipGraphNodeType node_type; + HIP_CHECK(hipGraphNodeGetType(nodes[i], &node_type)); + switch (i) { + case (0): + REQUIRE(node_type == hipGraphNodeTypeMemcpy); + break; + case (1): + REQUIRE(node_type == hipGraphNodeTypeMemset); + break; + case (2): + REQUIRE(node_type == hipGraphNodeTypeMemset); + break; + case (3): + REQUIRE(node_type == hipGraphNodeTypeKernel); + break; + case (4): + REQUIRE(node_type == hipGraphNodeTypeKernel); + break; + case (5): + REQUIRE(node_type == hipGraphNodeTypeMemcpy); + break; + } + } + + HIP_CHECK(hipGraphGetRootNodes(graph, nullptr, &numNodes)); INFO("Num of root nodes in the graph created using" " stream capture API" << numNodes); hipGraphExec_t graphExec; @@ -203,15 +232,15 @@ static void hipGraphsUsingStreamCapture(float* inputVec_h, float* inputVec_d, /** * Manual procedure of adding nodes to graphs and mapping dependencies. */ -static void hipGraphsManual(float* inputVec_h, float* inputVec_d, - double* outputVec_d, double* result_d, size_t inputSize, - size_t numOfBlocks) { +static void hipGraphsManual(float* inputVec_h, float* inputVec_d, double* outputVec_d, + double* result_d, size_t inputSize, size_t numOfBlocks, + unsigned int captureStreamFlags) { hipStream_t streamForGraph; hipGraph_t graph; std::vector nodeDependencies; hipGraphNode_t memcpyNode, kernelNode, memsetNode; double result_h = 0.0; - HIP_CHECK(hipStreamCreate(&streamForGraph)); + HIP_CHECK(hipStreamCreateWithFlags(&streamForGraph, captureStreamFlags)); auto start = std::chrono::high_resolution_clock::now(); hipKernelNodeParams kernelNodeParams{}; hipMemsetParams memsetParams{}; @@ -269,12 +298,39 @@ static void hipGraphsManual(float* inputVec_h, float* inputVec_d, nodeDependencies.clear(); nodeDependencies.push_back(memcpyNode); hipGraphExec_t graphExec; - hipGraphNode_t* nodes{nullptr}; - size_t numNodes{}; - HIP_CHECK(hipGraphGetNodes(graph, nodes, &numNodes)); - INFO("Num of nodes in the graph created using hipGraphs Manual" - << numNodes); - HIP_CHECK(hipGraphGetRootNodes(graph, nodes, &numNodes)); + size_t numNodes = 0; + HIP_CHECK(hipGraphGetNodes(graph, nullptr, &numNodes)); + + std::vector nodes(numNodes); + HIP_CHECK(hipGraphGetNodes(graph, nodes.data(), &numNodes)); + INFO("Num of nodes in the graph created using stream capture API" << numNodes); + REQUIRE(numNodes == 6); + for (size_t i = 0; i < nodes.size(); ++i) { + hipGraphNodeType node_type; + HIP_CHECK(hipGraphNodeGetType(nodes[i], &node_type)); + switch (i) { + case (0): + REQUIRE(node_type == hipGraphNodeTypeMemcpy); + break; + case (1): + REQUIRE(node_type == hipGraphNodeTypeMemset); + break; + case (2): + REQUIRE(node_type == hipGraphNodeTypeMemset); + break; + case (3): + REQUIRE(node_type == hipGraphNodeTypeKernel); + break; + case (4): + REQUIRE(node_type == hipGraphNodeTypeKernel); + break; + case (5): + REQUIRE(node_type == hipGraphNodeTypeMemcpy); + break; + } + } + + HIP_CHECK(hipGraphGetRootNodes(graph, nullptr, &numNodes)); INFO("Num of root nodes in the graph created using" " hipGraphs Manual" << numNodes); HIP_CHECK(hipGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0)); @@ -315,6 +371,7 @@ TEST_CASE("Unit_hipGraph_BasicFunctional") { constexpr size_t maxBlocks = 512; float *inputVec_d{nullptr}, *inputVec_h{nullptr}; double *outputVec_d{nullptr}, *result_d{nullptr}; + unsigned int streamFlags = GENERATE(hipStreamDefault, hipStreamNonBlocking); INFO("Elements : " << size << " ThreadsPerBlock : " << THREADS_PER_BLOCK); INFO("Graph Launch iterations = " << GRAPH_LAUNCH_ITERATIONS); @@ -328,18 +385,16 @@ TEST_CASE("Unit_hipGraph_BasicFunctional") { init_input(inputVec_h, size); SECTION("Execution Without HIPGraphs") { - hipWithoutGraphs(inputVec_h, inputVec_d, outputVec_d, - result_d, size, maxBlocks); + hipWithoutGraphs(inputVec_h, inputVec_d, outputVec_d, result_d, size, maxBlocks, streamFlags); } SECTION("Manual HIPGraph") { - hipGraphsManual(inputVec_h, inputVec_d, outputVec_d, - result_d, size, maxBlocks); + hipGraphsManual(inputVec_h, inputVec_d, outputVec_d, result_d, size, maxBlocks, streamFlags); } SECTION("HIPGraphs Using StreamCapture") { - hipGraphsUsingStreamCapture(inputVec_h, inputVec_d, - outputVec_d, result_d, size, maxBlocks); + hipGraphsUsingStreamCapture(inputVec_h, inputVec_d, outputVec_d, result_d, size, maxBlocks, + streamFlags); } HIP_CHECK(hipFree(inputVec_d));