Fix checks for CUDA memory pools support
Change-Id: Icc07625fc0e586e8798da48fa5edfde59487d702
diff --git a/internal/ceres/context_impl.cc b/internal/ceres/context_impl.cc
index c085644..2b9d9cc 100644
--- a/internal/ceres/context_impl.cc
+++ b/internal/ceres/context_impl.cc
@@ -102,7 +102,9 @@
gpu_device_properties_.maxGridSize[1],
gpu_device_properties_.maxGridSize[2],
gpu_device_properties_.multiProcessorCount,
- gpu_device_properties_.memoryPoolsSupported ? "Yes" : "No");
+ // cudaDeviceProp only has the memoryPoolsSupported field in CUDA 12.0.0+, so
+ // report the value queried via cudaDeviceGetAttribute instead.
+ is_cuda_memory_pools_supported_ ? "Yes" : "No");
}
size_t ContextImpl::GpuMemoryAvailable() const {
@@ -123,6 +125,14 @@
CHECK_EQ(
cudaGetDeviceProperties(&gpu_device_properties_, gpu_device_id_in_use_),
cudaSuccess);
+#if CUDART_VERSION >= 11020
+ int is_cuda_memory_pools_supported;
+ CHECK_EQ(cudaDeviceGetAttribute(&is_cuda_memory_pools_supported,
+ cudaDevAttrMemoryPoolsSupported,
+ gpu_device_id_in_use_),
+ cudaSuccess);
+ is_cuda_memory_pools_supported_ = is_cuda_memory_pools_supported == 1;
+#endif
VLOG(3) << "\n" << CudaConfigAsString();
EventLogger event_logger("InitCuda");
if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
diff --git a/internal/ceres/context_impl.h b/internal/ceres/context_impl.h
index 508fd06..46692e6 100644
--- a/internal/ceres/context_impl.h
+++ b/internal/ceres/context_impl.h
@@ -134,6 +134,7 @@
bool is_cuda_initialized_ = false;
int gpu_device_id_in_use_ = -1;
cudaDeviceProp gpu_device_properties_;
+ bool is_cuda_memory_pools_supported_ = false;
int cuda_version_major_ = 0;
int cuda_version_minor_ = 0;
#endif // CERES_NO_CUDA
diff --git a/internal/ceres/cuda_block_sparse_crs_view.cc b/internal/ceres/cuda_block_sparse_crs_view.cc
index c370d22..7564d52 100644
--- a/internal/ceres/cuda_block_sparse_crs_view.cc
+++ b/internal/ceres/cuda_block_sparse_crs_view.cc
@@ -52,7 +52,7 @@
rows.data(),
cols.data(),
context->DefaultStream(),
- context);
+ context->is_cuda_memory_pools_supported_);
is_crs_compatible_ = block_structure_->IsCrsCompatible();
// if matrix is crs-compatible - we can drop block-structure and don't need
// streamed_buffer_
diff --git a/internal/ceres/cuda_kernels_bsm_to_crs.cu.cc b/internal/ceres/cuda_kernels_bsm_to_crs.cu.cc
index 05b52e7..b9ca4cd 100644
--- a/internal/ceres/cuda_kernels_bsm_to_crs.cu.cc
+++ b/internal/ceres/cuda_kernels_bsm_to_crs.cu.cc
@@ -55,11 +55,11 @@
cudaStream_t stream,
bool memory_pools_supported) {
void* data = nullptr;
- // Stream-ordered alloaction API is available since CUDA 11.4, but might be
+ // Stream-ordered allocation API is available since CUDA 11.2, but might be
// not implemented by particular device
-#if CUDART_VERSION < 11040
+#if CUDART_VERSION < 11020
#warning \
- "Stream-ordered allocations are unavailable, consider updating CUDA toolkit to version 11.4+"
+ "Stream-ordered allocations are unavailable, consider updating CUDA toolkit to version 11.2+"
cudaMalloc(&data, size);
#else
if (memory_pools_supported) {
@@ -72,11 +72,11 @@
}
void CudaFree(void* data, cudaStream_t stream, bool memory_pools_supported) {
- // Stream-ordered alloaction API is available since CUDA 11.4, but might be
+ // Stream-ordered allocation API is available since CUDA 11.2, but might be
// not implemented by particular device
-#if CUDART_VERSION < 11040
+#if CUDART_VERSION < 11020
#warning \
- "Stream-ordered allocations are unavailable, consider updating CUDA toolkit to version 11.4+"
+ "Stream-ordered allocations are unavailable, consider updating CUDA toolkit to version 11.2+"
cudaSuccess, cudaFree(data);
#else
if (memory_pools_supported) {
diff --git a/internal/ceres/cuda_partitioned_block_sparse_crs_view.cc b/internal/ceres/cuda_partitioned_block_sparse_crs_view.cc
index 550e6d0..c0c1dc8 100644
--- a/internal/ceres/cuda_partitioned_block_sparse_crs_view.cc
+++ b/internal/ceres/cuda_partitioned_block_sparse_crs_view.cc
@@ -80,7 +80,7 @@
rows_f.data(),
cols_f.data(),
context->DefaultStream(),
- context);
+ context->is_cuda_memory_pools_supported_);
f_is_crs_compatible_ = block_structure_->IsCrsCompatible();
if (f_is_crs_compatible_) {
block_structure_ = nullptr;