29
29
/* NOTE(review): presumably the buffer size for a CUDA device name string;
 * the users of this constant are not visible here - confirm. */
#define UCT_CUDA_DEV_NAME_MAX_LEN 64
30
30
/* NOTE(review): presumably an upper bound on the number of CUDA devices
 * handled by this module; the users are not visible here - confirm. */
#define UCT_CUDA_MAX_DEVICES 32
31
31
32
/*
 * CUDA driver version threshold compared against cuDriverGetVersion() output
 * (original note: for VMM: cuCtxSetFlags() >= cuda 12.1).
 * NOTE(review): with the decoding macros below 12030 reads as 12.3, not
 * 12.1 - confirm which minimum version is actually intended.
 */
#define UCT_CUDA_VERSION_VMM 12030

/*
 * Decode a version number laid out as (major * 1000 + minor * 10).
 * There must be no whitespace between the macro name and '(' - otherwise
 * the preprocessor treats these as object-like macros.
 */
#define UCT_CUDA_MAJOR(_version) ((_version) / 1000)
#define UCT_CUDA_MINOR(_version) (((_version) % 1000) / 10)
35
+
32
36
33
37
static const char * uct_cuda_pref_loc [] = {
34
38
[UCT_CUDA_PREF_LOC_CPU ] = "cpu" ,
@@ -515,22 +519,27 @@ static size_t uct_cuda_copy_md_get_total_device_mem(CUdevice cuda_device)
515
519
static void
516
520
uct_cuda_copy_sync_memops (uct_cuda_copy_md_t * md , const void * address )
517
521
{
522
+ unsigned value = 1 ;
523
+
518
524
#if HAVE_CUDA_FABRIC
519
525
ucs_status_t status ;
520
- if (!md -> sync_memops_set ) {
521
- /* Synchronize future DMA operations for all memory types */
522
- status = UCT_CUDADRV_FUNC_LOG_WARN (cuCtxSetFlags (CU_CTX_SYNC_MEMOPS ));
523
- if (status == UCS_OK ) {
524
- md -> sync_memops_set = 1 ;
526
+ if (md -> config .cuda_ctx_set_flags ) {
527
+ if (!md -> sync_memops_set ) {
528
+ /* Synchronize future DMA operations for all memory types */
529
+ status = UCT_CUDADRV_FUNC_LOG_WARN (cuCtxSetFlags (CU_CTX_SYNC_MEMOPS ));
530
+ if (status == UCS_OK ) {
531
+ md -> sync_memops_set = 1 ;
532
+ }
525
533
}
534
+
535
+ return ;
526
536
}
527
- #else
528
- unsigned value = 1 ;
537
+ #endif
538
+
529
539
/* Synchronize for DMA for legacy memory types*/
530
540
UCT_CUDADRV_FUNC_LOG_WARN (
531
541
cuPointerSetAttribute (& value , CU_POINTER_ATTRIBUTE_SYNC_MEMOPS ,
532
542
(CUdeviceptr )address ));
533
- #endif
534
543
}
535
544
536
545
static ucs_status_t
@@ -830,7 +839,7 @@ uct_cuda_copy_md_open(uct_component_t *component, const char *md_name,
830
839
uct_cuda_copy_md_config_t * config = ucs_derived_of (md_config ,
831
840
uct_cuda_copy_md_config_t );
832
841
uct_cuda_copy_md_t * md ;
833
- int dmabuf_supported ;
842
+ int dmabuf_supported , version ;
834
843
ucs_status_t status ;
835
844
836
845
md = ucs_malloc (sizeof (uct_cuda_copy_md_t ), "uct_cuda_copy_md_t" );
@@ -840,15 +849,32 @@ uct_cuda_copy_md_open(uct_component_t *component, const char *md_name,
840
849
goto err ;
841
850
}
842
851
843
- md -> super .ops = & md_ops ;
844
- md -> super .component = & uct_cuda_copy_component ;
845
- md -> config .alloc_whole_reg = config -> alloc_whole_reg ;
846
- md -> config .max_reg_ratio = config -> max_reg_ratio ;
847
- md -> config .pref_loc = config -> pref_loc ;
848
- md -> config .enable_fabric = config -> enable_fabric ;
849
- md -> config .dmabuf_supported = 0 ;
850
- md -> sync_memops_set = 0 ;
851
- md -> granularity = SIZE_MAX ;
852
+ md -> super .ops = & md_ops ;
853
+ md -> super .component = & uct_cuda_copy_component ;
854
+ md -> config .alloc_whole_reg = config -> alloc_whole_reg ;
855
+ md -> config .max_reg_ratio = config -> max_reg_ratio ;
856
+ md -> config .pref_loc = config -> pref_loc ;
857
+ md -> config .enable_fabric = config -> enable_fabric ;
858
+ md -> config .dmabuf_supported = 0 ;
859
+ md -> config .cuda_ctx_set_flags = 1 ;
860
+ md -> sync_memops_set = 0 ;
861
+ md -> granularity = SIZE_MAX ;
862
+
863
+ #if HAVE_CUDA_FABRIC
864
+ if ((cuDriverGetVersion (& version ) == CUDA_SUCCESS ) &&
865
+ (version < UCT_CUDA_VERSION_VMM )) {
866
+ if (md -> config .enable_fabric != UCS_NO ) {
867
+ ucs_warn ("disabled fabric memory allocations as cuda driver "
868
+ "library %d.%d < %d.%d" ,
869
+ UCT_CUDA_MAJOR (version ), UCT_CUDA_MINOR (version ),
870
+ UCT_CUDA_MAJOR (UCT_CUDA_VERSION_VMM ),
871
+ UCT_CUDA_MINOR (UCT_CUDA_VERSION_VMM ));
872
+ }
873
+
874
+ md -> config .enable_fabric = UCS_NO ;
875
+ md -> config .cuda_ctx_set_flags = 0 ;
876
+ }
877
+ #endif
852
878
853
879
if ((config -> cuda_async_mem_type != UCS_MEMORY_TYPE_CUDA ) &&
854
880
(config -> cuda_async_mem_type != UCS_MEMORY_TYPE_CUDA_MANAGED )) {
0 commit comments