Add support for dense CUDA solvers #3
1. Add CUDA initialization and cleanup management to the
ContextImpl object. The ContextImpl is now solely responsible
for managing CUDA-related resources.
2. All CUDA dense solvers now use lazy CUDA initialization
via the ContextImpl object.
Change-Id: Ief456860c72e462367ee997d389c19e2bff50baf
diff --git a/internal/ceres/context_impl.cc b/internal/ceres/context_impl.cc
index 1acf724..ed239a5 100644
--- a/internal/ceres/context_impl.cc
+++ b/internal/ceres/context_impl.cc
@@ -30,11 +30,73 @@
#include "ceres/context_impl.h"
+#include <string>
+
+#ifndef CERES_NO_CUDA
+#include "cuda_runtime.h"
+#include "cublas_v2.h"
+#include "cusolverDn.h"
+#endif // CERES_NO_CUDA
+
namespace ceres {
namespace internal {
ContextImpl::ContextImpl() = default;
+#ifndef CERES_NO_CUDA
+// Creates, on first use, the cuBLAS handle, the cuSolverDN handle and the
+// CUDA stream owned by this context, and binds both handles to that stream.
+//
+// Returns true if the resources are ready (including when a previous call
+// already succeeded). On failure, returns false, writes a description of the
+// failing call to *message, destroys whatever had been created so far and
+// resets the corresponding members to nullptr, so a later call starts over.
+//
+// message is dereferenced unconditionally on every failure path, so it must
+// not be null. No locking is performed inside this function.
+bool ContextImpl::InitCUDA(std::string* message) {
+  if (!cuda_initialized_) {
+    // Step 1: cuBLAS handle. Nothing else exists yet, so nothing to undo.
+    const cublasStatus_t cublas_created = cublasCreate(&cublas_handle_);
+    if (cublas_created != CUBLAS_STATUS_SUCCESS) {
+      cublas_handle_ = nullptr;
+      *message = "cuBLAS::cublasCreate failed.";
+      return false;
+    }
+
+    // Step 2: cuSolverDN handle. On failure, undo step 1.
+    const cusolverStatus_t cusolver_created =
+        cusolverDnCreate(&cusolver_handle_);
+    if (cusolver_created != CUSOLVER_STATUS_SUCCESS) {
+      cusolver_handle_ = nullptr;
+      cublasDestroy(cublas_handle_);
+      cublas_handle_ = nullptr;
+      *message = "cuSolverDN::cusolverDnCreate failed.";
+      return false;
+    }
+
+    // Step 3: the stream, created with the cudaStreamNonBlocking flag.
+    // On failure, undo steps 1 and 2.
+    const cudaError_t stream_created =
+        cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking);
+    if (stream_created != cudaSuccess) {
+      cusolverDnDestroy(cusolver_handle_);
+      cusolver_handle_ = nullptr;
+      cublasDestroy(cublas_handle_);
+      cublas_handle_ = nullptr;
+      stream_ = nullptr;
+      *message = "CUDA::cudaStreamCreateWithFlags failed.";
+      return false;
+    }
+
+    // Step 4: attach both handles to the stream. The && short-circuits, so
+    // cublasSetStream is only attempted once cusolverDnSetStream succeeded.
+    const bool handles_bound_to_stream =
+        cusolverDnSetStream(cusolver_handle_, stream_) ==
+            CUSOLVER_STATUS_SUCCESS &&
+        cublasSetStream(cublas_handle_, stream_) == CUBLAS_STATUS_SUCCESS;
+    if (!handles_bound_to_stream) {
+      // Handles go first, then the stream they may have been bound to.
+      cusolverDnDestroy(cusolver_handle_);
+      cusolver_handle_ = nullptr;
+      cublasDestroy(cublas_handle_);
+      cublas_handle_ = nullptr;
+      cudaStreamDestroy(stream_);
+      stream_ = nullptr;
+      *message =
+          "cuSolverDN::cusolverDnSetStream or cuBLAS::cublasSetStream failed.";
+      return false;
+    }
+
+    // Only a fully constructed set of resources marks the context as ready.
+    cuda_initialized_ = true;
+  }
+  return true;
+}
+#endif // CERES_NO_CUDA
+
+// Releases the cuSolverDN handle, the cuBLAS handle and the CUDA stream, but
+// only if InitCUDA() completed successfully; otherwise there is nothing to
+// free. Compiles to an empty destructor when CERES_NO_CUDA is defined.
+ContextImpl::~ContextImpl() {
+#ifndef CERES_NO_CUDA
+  if (!cuda_initialized_) {
+    return;
+  }
+  // The two library handles are destroyed before the stream they were
+  // attached to in InitCUDA().
+  cusolverDnDestroy(cusolver_handle_);
+  cublasDestroy(cublas_handle_);
+  cudaStreamDestroy(stream_);
+#endif  // CERES_NO_CUDA
+}
void ContextImpl::EnsureMinimumThreads(int num_threads) {
#ifdef CERES_USE_CXX_THREADS
thread_pool.Resize(num_threads);