diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh
index e62df3e21f..fe5ba32496 100644
--- a/cpp/src/fil/common.cuh
+++ b/cpp/src/fil/common.cuh
@@ -309,17 +309,12 @@ struct compute_smem_footprint : dispatch_functor<int> {
   int run(predict_params);
 };
 
-template <int NITEMS,
-          leaf_algo_t leaf_algo,
-          bool cols_in_shmem,
-          bool CATS_SUPPORTED,
-          class storage_type>
-__attribute__((visibility("hidden"))) __global__ void infer_k(storage_type forest,
-                                                              predict_params params);
-
 // infer() calls the inference kernel with the parameters on the stream
 template <typename storage_type>
 void infer(storage_type forest, predict_params params, cudaStream_t stream);
 
+template <typename storage_type>
+void infer_shared_mem_size(predict_params params, int max_shm);
+
 }  // namespace fil
 }  // namespace ML
diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu
index e0d2f8baaf..69b0320e1e 100644
--- a/cpp/src/fil/fil.cu
+++ b/cpp/src/fil/fil.cu
@@ -349,26 +349,6 @@ struct forest {
   cat_sets_device_owner cat_sets_;
 };
 
-template <class storage_type>
-struct opt_into_arch_dependent_shmem : dispatch_functor<void> {
-  const int max_shm;
-  opt_into_arch_dependent_shmem(int max_shm_) : max_shm(max_shm_) {}
-
-  template <class KernelParams = KernelTemplateParams<>>
-  void run(predict_params p)
-  {
-    auto kernel = infer_k<KernelParams::NITEMS,
-                          KernelParams::LEAF_ALGO,
-                          KernelParams::COLS_IN_SHMEM,
-                          KernelParams::CATS_SUPPORTED,
-                          storage_type>;
-    // p.shm_sz might be > max_shm or < MAX_SHM_STD, but we should not check for either, because
-    // we don't run on both proba_ssp_ and class_ssp_ (only class_ssp_). This should be quick.
-    RAFT_CUDA_TRY(
-      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm));
-  }
-};
-
 template <typename real_t>
 struct dense_forest<dense_node<real_t>> : forest<real_t> {
   using node_t = dense_node<real_t>;
@@ -427,8 +407,9 @@ struct dense_forest<dense_node<real_t>> : forest<real_t> {
                                   h.get_stream()));
 
     // predict_proba is a runtime parameter, and opt-in is unconditional
-    dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<storage<node_t>>(this->max_shm_),
-                                    static_cast<predict_params>(this->class_ssp_));
+    fil::infer_shared_mem_size<storage<node_t>>(static_cast<predict_params>(this->class_ssp_),
+                                                this->max_shm_);
+
     // copy must be finished before freeing the host data
     h.sync_stream();
     h_nodes_.clear();
@@ -491,8 +472,8 @@ struct sparse_forest : forest<typename node_t::real_type> {
       nodes_.data(), nodes, sizeof(node_t) * num_nodes_, cudaMemcpyHostToDevice, h.get_stream()));
 
     // predict_proba is a runtime parameter, and opt-in is unconditional
-    dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<storage<node_t>>(this->max_shm_),
-                                    static_cast<predict_params>(this->class_ssp_));
+    fil::infer_shared_mem_size<storage<node_t>>(static_cast<predict_params>(this->class_ssp_),
+                                                this->max_shm_);
   }
 
   virtual void infer(predict_params params, cudaStream_t stream) override
diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu
index 574a0a37e3..c3bdd1b810 100644
--- a/cpp/src/fil/infer.cu
+++ b/cpp/src/fil/infer.cu
@@ -908,12 +908,38 @@ struct infer_k_storage_template : dispatch_functor<void> {
   }
 };
 
+template <class storage_type>
+struct opt_into_arch_dependent_shmem : dispatch_functor<void> {
+  const int max_shm;
+  opt_into_arch_dependent_shmem(int max_shm_) : max_shm(max_shm_) {}
+
+  template <class KernelParams = KernelTemplateParams<>>
+  void run(predict_params p)
+  {
+    auto kernel = infer_k<KernelParams::NITEMS,
+                          KernelParams::LEAF_ALGO,
+                          KernelParams::COLS_IN_SHMEM,
+                          KernelParams::CATS_SUPPORTED,
+                          storage_type>;
+    // p.shm_sz might be > max_shm or < MAX_SHM_STD, but we should not check for either, because
+    // we don't run on both proba_ssp_ and class_ssp_ (only class_ssp_). This should be quick.
+    RAFT_CUDA_TRY(
+      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm));
+  }
+};
+
 template <typename storage_type>
 void infer(storage_type forest, predict_params params, cudaStream_t stream)
 {
   dispatch_on_fil_template_params(infer_k_storage_template<storage_type>(forest, stream), params);
 }
 
+template <typename storage_type>
+void infer_shared_mem_size(predict_params params, int max_shm)
+{
+  dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<storage_type>(max_shm), params);
+}
+
 template void infer<dense_storage_f32>(dense_storage_f32 forest,
                                        predict_params params,
                                        cudaStream_t stream);
@@ -930,5 +956,11 @@ template void infer<sparse_storage8>(sparse_storage8 forest,
                                      predict_params params,
                                      cudaStream_t stream);
 
+template void infer_shared_mem_size<dense_storage_f32>(predict_params params, int max_shm);
+template void infer_shared_mem_size<dense_storage_f64>(predict_params params, int max_shm);
+template void infer_shared_mem_size<sparse_storage16_f32>(predict_params params, int max_shm);
+template void infer_shared_mem_size<sparse_storage16_f64>(predict_params params, int max_shm);
+template void infer_shared_mem_size<sparse_storage8>(predict_params params, int max_shm);
+
 }  // namespace fil
 }  // namespace ML
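
For context on what the relocated `opt_into_arch_dependent_shmem` functor actually does, here is a minimal, self-contained CUDA sketch of the same technique: opting a kernel into the architecture-dependent dynamic shared memory limit with `cudaFuncSetAttribute` before launching it with more than the default 48 KiB. It is not taken from the cuML sources; the file, kernel, and variable names are made up for illustration.

```cuda
// shmem_opt_in_example.cu -- illustrative sketch only, not part of the FIL sources.
#include <cuda_runtime.h>

#include <cstdio>

// Toy kernel that stages its input through dynamic shared memory.
__global__ void copy_through_shmem(const float* in, float* out, int n)
{
  extern __shared__ float buf[];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    buf[threadIdx.x] = in[i];
    out[i]           = buf[threadIdx.x];
  }
}

int main()
{
  int device = 0;
  cudaGetDevice(&device);

  // Architecture-dependent per-block limit a kernel can opt into (e.g. larger on
  // Volta/Ampere than the 48 KiB default); without the opt-in, launches that
  // request more than 48 KiB of dynamic shared memory fail.
  int max_shm = 0;
  cudaDeviceGetAttribute(&max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  // The opt-in itself -- the same cudaFuncSetAttribute call the diff wraps in
  // RAFT_CUDA_TRY, applied here to the toy kernel through a function pointer.
  auto kernel = copy_through_shmem;
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm);

  const int n            = 1 << 20;
  const int threads      = 256;
  const int blocks       = (n + threads - 1) / threads;
  // Request the full opted-in amount of dynamic shared memory for the launch.
  const size_t shm_bytes = static_cast<size_t>(max_shm);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));

  kernel<<<blocks, threads, shm_bytes>>>(d_in, d_out, n);
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

The `infer_shared_mem_size` entry point added by the diff performs this same opt-in for every `infer_k` template instantiation of a given storage type, which is what lets fil.cu drop its declaration of the hidden `infer_k` kernel symbol.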