Add support for dense CUDA solvers #3

1. Add CUDA initialization and cleanup management to the
   ContextImpl object. The ContextImpl is now solely responsible
   for managing CUDA-related resources.
2. All CUDA dense solvers now use lazy CUDA initialization
   via the ContextImpl object.

Change-Id: Ief456860c72e462367ee997d389c19e2bff50baf
diff --git a/internal/ceres/context_impl.cc b/internal/ceres/context_impl.cc
index 1acf724..ed239a5 100644
--- a/internal/ceres/context_impl.cc
+++ b/internal/ceres/context_impl.cc
@@ -30,11 +30,73 @@
 
 #include "ceres/context_impl.h"
 
+#include <string>
+
+#ifndef CERES_NO_CUDA
+#include "cuda_runtime.h"
+#include "cublas_v2.h"
+#include "cusolverDn.h"
+#endif  // CERES_NO_CUDA
+
 namespace ceres {
 namespace internal {
 
 ContextImpl::ContextImpl() = default;
 
+#ifndef CERES_NO_CUDA
+// Creates the cuBLAS and cuSolverDN handles and the stream they both run on.
+// On failure, releases whatever was created and explains why in *message.
+bool ContextImpl::InitCUDA(std::string* message) {
+  if (cuda_initialized_) return true;  // An earlier call already succeeded.
+  if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
+    *message = "cuBLAS::cublasCreate failed.";
+    cublas_handle_ = nullptr;
+    return false;
+  }
+  if (cusolverDnCreate(&cusolver_handle_) != CUSOLVER_STATUS_SUCCESS) {
+    *message = "cuSolverDN::cusolverDnCreate failed.";
+    cublasDestroy(cublas_handle_);  // Undo the one step that did succeed.
+    cublas_handle_ = nullptr;
+    cusolver_handle_ = nullptr;
+    return false;
+  }
+  if (cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking) !=
+      cudaSuccess) {
+    *message = "CUDA::cudaStreamCreateWithFlags failed.";
+    cusolverDnDestroy(cusolver_handle_);
+    cusolver_handle_ = nullptr;
+    cublasDestroy(cublas_handle_);
+    cublas_handle_ = nullptr;
+    stream_ = nullptr;
+    return false;
+  }
+  if (!(cusolverDnSetStream(cusolver_handle_, stream_) ==
+            CUSOLVER_STATUS_SUCCESS &&
+        cublasSetStream(cublas_handle_, stream_) == CUBLAS_STATUS_SUCCESS)) {
+    *message =
+        "cuSolverDN::cusolverDnSetStream or cuBLAS::cublasSetStream failed.";
+    cusolverDnDestroy(cusolver_handle_);
+    cusolver_handle_ = nullptr;
+    cublasDestroy(cublas_handle_);
+    cublas_handle_ = nullptr;
+    cudaStreamDestroy(stream_);
+    stream_ = nullptr;
+    return false;
+  }
+  cuda_initialized_ = true;
+  return true;
+}
+#endif  // CERES_NO_CUDA
+
+// Releases the CUDA handles and stream, but only if InitCUDA succeeded.
+ContextImpl::~ContextImpl() {
+#ifndef CERES_NO_CUDA
+  if (!cuda_initialized_) return;
+  cusolverDnDestroy(cusolver_handle_);
+  cublasDestroy(cublas_handle_);
+  cudaStreamDestroy(stream_);
+#endif  // CERES_NO_CUDA
+}
 void ContextImpl::EnsureMinimumThreads(int num_threads) {
 #ifdef CERES_USE_CXX_THREADS
   thread_pool.Resize(num_threads);
diff --git a/internal/ceres/context_impl.h b/internal/ceres/context_impl.h
index 7d1e6d3..1944549 100644
--- a/internal/ceres/context_impl.h
+++ b/internal/ceres/context_impl.h
@@ -40,6 +40,12 @@
 #include "ceres/internal/disable_warnings.h"
 #include "ceres/internal/export.h"
 
+#ifndef CERES_NO_CUDA
+#include "cuda_runtime.h"
+#include "cublas_v2.h"
+#include "cusolverDn.h"
+#endif  // CERES_NO_CUDA
+
 #ifdef CERES_USE_CXX_THREADS
 #include "ceres/thread_pool.h"
 #endif  // CERES_USE_CXX_THREADS
@@ -50,6 +56,7 @@
 class CERES_NO_EXPORT ContextImpl : public Context {
  public:
   ContextImpl();
+  ~ContextImpl() override;
   ContextImpl(const ContextImpl&) = delete;
   void operator=(const ContextImpl&) = delete;
 
@@ -62,6 +69,23 @@
 #ifdef CERES_USE_CXX_THREADS
   ThreadPool thread_pool;
 #endif  // CERES_USE_CXX_THREADS
+
+#ifndef CERES_NO_CUDA
+  // Initializes the cuBLAS and cuSolverDN contexts, creates an asynchronous
+  // stream, and associates the stream with both libraries. Returns true iff
+  // initialization was successful, else it returns false and a human-readable
+  // error message is returned in *message.
+  bool InitCUDA(std::string* message);
+
+  // Handle to the cuSOLVER context.
+  cusolverDnHandle_t cusolver_handle_ = nullptr;
+  // Handle to cuBLAS context.
+  cublasHandle_t cublas_handle_ = nullptr;
+  // CUDA device stream.
+  cudaStream_t stream_ = nullptr;
+  // Indicates whether all the CUDA resources have been initialized.
+  bool cuda_initialized_ = false;
+#endif  // CERES_NO_CUDA
 };
 
 }  // namespace internal
diff --git a/internal/ceres/cuda_dense_cholesky_test.cc b/internal/ceres/cuda_dense_cholesky_test.cc
index 6c4dcc3..cca97d8 100644
--- a/internal/ceres/cuda_dense_cholesky_test.cc
+++ b/internal/ceres/cuda_dense_cholesky_test.cc
@@ -43,6 +43,8 @@
 
 TEST(CUDADenseCholesky, InvalidOptionOnCreate) {
   LinearSolver::Options options;
+  ContextImpl context;  // Owner of the CUDA handles that solvers borrow.
+  options.context = &context;
   auto dense_cuda_solver = CUDADenseCholesky::Create(options);
   EXPECT_EQ(dense_cuda_solver, nullptr);
 }
@@ -56,6 +58,8 @@
         0,   0,   0, 1;
   const Eigen::Vector4d b = Eigen::Vector4d::Ones();
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = CUDA;
   auto dense_cuda_solver = CUDADenseCholesky::Create(options);
   ASSERT_NE(dense_cuda_solver, nullptr);
@@ -80,6 +84,8 @@
         0, 0, 0;
   const Eigen::Vector3d b = Eigen::Vector3d::Ones();
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = CUDA;
   auto dense_cuda_solver = CUDADenseCholesky::Create(options);
   ASSERT_NE(dense_cuda_solver, nullptr);
@@ -97,6 +103,8 @@
         0, 0, -1;
   const Eigen::Vector3d b = Eigen::Vector3d::Ones();
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = CUDA;
   auto dense_cuda_solver = CUDADenseCholesky::Create(options);
   ASSERT_NE(dense_cuda_solver, nullptr);
@@ -110,6 +118,8 @@
 TEST(CUDADenseCholesky, MustFactorizeBeforeSolve) {
   const Eigen::Vector3d b = Eigen::Vector3d::Ones();
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = CUDA;
   auto dense_cuda_solver = CUDADenseCholesky::Create(options);
   ASSERT_NE(dense_cuda_solver, nullptr);
diff --git a/internal/ceres/cuda_dense_qr_test.cc b/internal/ceres/cuda_dense_qr_test.cc
index 15ba00a..6a64298 100644
--- a/internal/ceres/cuda_dense_qr_test.cc
+++ b/internal/ceres/cuda_dense_qr_test.cc
@@ -56,6 +56,8 @@
         0,   0,   0, 1;
   const Eigen::Vector4d b = Eigen::Vector4d::Ones();
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = CUDA;
   auto dense_cuda_solver = CUDADenseQR::Create(options);
   ASSERT_NE(dense_cuda_solver, nullptr);
@@ -85,6 +87,8 @@
         0,   0;
   const std::vector<double> b(4, 1.0);
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = CUDA;
   auto dense_cuda_solver = CUDADenseQR::Create(options);
   ASSERT_NE(dense_cuda_solver, nullptr);
@@ -107,6 +111,8 @@
 TEST(CUDADenseQR, MustFactorizeBeforeSolve) {
   const Eigen::Vector3d b = Eigen::Vector3d::Ones();
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = CUDA;
   auto dense_cuda_solver = CUDADenseQR::Create(options);
   ASSERT_NE(dense_cuda_solver, nullptr);
@@ -123,6 +129,8 @@
   using SolutionType = Eigen::Matrix<double, Eigen::Dynamic, 1>;
 
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = ceres::CUDA;
   std::unique_ptr<DenseQR> dense_qr = CUDADenseQR::Create(options);
 
diff --git a/internal/ceres/dense_cholesky.cc b/internal/ceres/dense_cholesky.cc
index 2d8c2da..f426df5 100644
--- a/internal/ceres/dense_cholesky.cc
+++ b/internal/ceres/dense_cholesky.cc
@@ -36,6 +36,7 @@
 #include <vector>
 
 #ifndef CERES_NO_CUDA
+#include "ceres/context_impl.h"
 #include "cuda_runtime.h"
 #include "cusolverDn.h"
 #endif  // CERES_NO_CUDA
@@ -193,36 +194,17 @@
 
 #ifndef CERES_NO_CUDA
 
-bool CUDADenseCholesky::Init(std::string* message) {
-  if (cusolverDnCreate(&cusolver_handle_) != CUSOLVER_STATUS_SUCCESS) {
-    *message = "cuSolverDN::cusolverDnCreate failed.";
+bool CUDADenseCholesky::Init(ContextImpl* context, std::string* message) {
+  if (!CHECK_NOTNULL(context)->InitCUDA(message)) {  // Null context is a bug.
     return false;
   }
-  if (cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking) !=
-      cudaSuccess) {
-    *message = "CUDA::cudaStreamCreateWithFlags failed.";
-    cusolverDnDestroy(cusolver_handle_);
-    return false;
-  }
-  if (cusolverDnSetStream(cusolver_handle_, stream_) !=
-      CUSOLVER_STATUS_SUCCESS) {
-    *message = "cuSolverDN::cusolverDnSetStream failed.";
-    cudaStreamDestroy(stream_);
-    cusolverDnDestroy(cusolver_handle_);
-    return false;
-  }
+  cusolver_handle_ = context->cusolver_handle_;  // Borrowed; context owns it.
+  stream_ = context->stream_;                    // Borrowed; context owns it.
   error_.Reserve(1);
   *message = "CUDADenseCholesky::Init Success.";
   return true;
 }
 
-CUDADenseCholesky::~CUDADenseCholesky() {
-  if (cusolver_handle_ != nullptr) {
-    CHECK_EQ(cusolverDnDestroy(cusolver_handle_), CUSOLVER_STATUS_SUCCESS);
-    CHECK_EQ(cudaStreamDestroy(stream_), cudaSuccess);
-  }
-}
-
 LinearSolverTerminationType CUDADenseCholesky::Factorize(
     int num_cols, double* lhs, std::string* message) {
   factorize_result_ = LinearSolverTerminationType::LINEAR_SOLVER_FATAL_ERROR;
@@ -326,7 +308,7 @@
   auto cuda_dense_cholesky =
       std::unique_ptr<CUDADenseCholesky>(new CUDADenseCholesky());
   std::string cuda_error;
-  if (cuda_dense_cholesky->Init(&cuda_error)) {
+  if (cuda_dense_cholesky->Init(options.context, &cuda_error)) {
     return cuda_dense_cholesky;
   }
   // Initialization failed, destroy the object (done automatically) and return a
diff --git a/internal/ceres/dense_cholesky.h b/internal/ceres/dense_cholesky.h
index 49d780c..b40e69a 100644
--- a/internal/ceres/dense_cholesky.h
+++ b/internal/ceres/dense_cholesky.h
@@ -44,6 +44,7 @@
 #include "ceres/linear_solver.h"
 #include "glog/logging.h"
 #ifndef CERES_NO_CUDA
+#include "ceres/context_impl.h"
 #include "cuda_runtime.h"
 #include "cusolverDn.h"
 #endif  // CERES_NO_CUDA
@@ -140,7 +141,6 @@
  public:
   static std::unique_ptr<CUDADenseCholesky> Create(
       const LinearSolver::Options& options);
-  ~CUDADenseCholesky() override;
   CUDADenseCholesky(const CUDADenseCholesky&) = delete;
   CUDADenseCholesky& operator=(const CUDADenseCholesky&) = delete;
   LinearSolverTerminationType Factorize(int num_cols,
@@ -152,11 +152,10 @@
 
  private:
   CUDADenseCholesky() = default;
-  // Initializes the cuSolverDN context, creates an asynchronous stream, and
-  // associates the stream with cuSolverDN. Returns true iff initialization was
-  // successful, else it returns false and a human-readable error message is
-  // returned.
-  bool Init(std::string* message);
+  // Picks up the cuSolverDN handle and the CUDA stream from the context. If
+  // the context is unable to initialize CUDA, returns false with a
+  // human-readable message indicating the reason.
+  bool Init(ContextImpl* context, std::string* message);
 
   // Handle to the cuSOLVER context.
   cusolverDnHandle_t cusolver_handle_ = nullptr;
diff --git a/internal/ceres/dense_cholesky_test.cc b/internal/ceres/dense_cholesky_test.cc
index eb1c336..034206a 100644
--- a/internal/ceres/dense_cholesky_test.cc
+++ b/internal/ceres/dense_cholesky_test.cc
@@ -65,6 +65,8 @@
   using VectorType = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
 
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = GetParam();
   std::unique_ptr<DenseCholesky> dense_cholesky =
       DenseCholesky::Create(options);
diff --git a/internal/ceres/dense_qr.cc b/internal/ceres/dense_qr.cc
index 77f04ad..ad9b64e 100644
--- a/internal/ceres/dense_qr.cc
+++ b/internal/ceres/dense_qr.cc
@@ -34,6 +34,7 @@
 #include <memory>
 #include <string>
 #ifndef CERES_NO_CUDA
+#include "ceres/context_impl.h"
 #include "cusolverDn.h"
 #include "cublas_v2.h"
 #endif  // CERES_NO_CUDA
@@ -310,53 +311,18 @@
 
 #ifndef CERES_NO_CUDA
 
-bool CUDADenseQR::Init(std::string* message) {
-  if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
-    *message = "cuBLAS::cublasCreate failed.";
+bool CUDADenseQR::Init(ContextImpl* context, std::string* message) {
+  if (!CHECK_NOTNULL(context)->InitCUDA(message)) {  // Null context is a bug.
     return false;
   }
-  if (cusolverDnCreate(&cusolver_handle_) != CUSOLVER_STATUS_SUCCESS) {
-    *message = "cuSolverDN::cusolverDnCreate failed.";
-    return false;
-  }
-  if (cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking) !=
-      cudaSuccess) {
-    *message = "CUDA::cudaStreamCreateWithFlags failed.";
-    cusolverDnDestroy(cusolver_handle_);
-    cublasDestroy(cublas_handle_);
-    return false;
-  }
-  if (cusolverDnSetStream(cusolver_handle_, stream_) != CUSOLVER_STATUS_SUCCESS) {
-    *message = "cuSolverDN::cusolverDnSetStream failed.";
-    cusolverDnDestroy(cusolver_handle_);
-    cudaStreamDestroy(stream_);
-    cublasDestroy(cublas_handle_);
-    return false;
-  }
-  if (cublasSetStream(cublas_handle_, stream_) != CUBLAS_STATUS_SUCCESS) {
-    *message = "cuBLAS::cublasSetStream failed.";
-    cusolverDnDestroy(cusolver_handle_);
-    cublasDestroy(cublas_handle_);
-    cudaStreamDestroy(stream_);
-    return false;
-  }
+  cublas_handle_ = context->cublas_handle_;      // Borrowed; context owns it.
+  cusolver_handle_ = context->cusolver_handle_;  // Borrowed; context owns it.
+  stream_ = context->stream_;                    // Borrowed; context owns it.
   error_.Reserve(1);
   *message = "CUDADenseQR::Init Success.";
   return true;
 }
 
-CUDADenseQR::~CUDADenseQR() {
-  if (cublas_handle_ != nullptr) {
-    CHECK_EQ(cublasDestroy(cublas_handle_), CUBLAS_STATUS_SUCCESS);
-  }
-  if (cusolver_handle_ != nullptr) {
-    CHECK_EQ(cusolverDnDestroy(cusolver_handle_), CUSOLVER_STATUS_SUCCESS);
-  }
-  if (stream_ != nullptr) {
-    CHECK_EQ(cudaStreamDestroy(stream_), cudaSuccess);
-  }
-}
-
 LinearSolverTerminationType CUDADenseQR::Factorize(
     int num_rows, int num_cols, double* lhs, std::string* message) {
   factorize_result_ = LinearSolverTerminationType::LINEAR_SOLVER_FATAL_ERROR;
@@ -496,7 +462,7 @@
   auto cuda_dense_qr =
       std::unique_ptr<CUDADenseQR>(new CUDADenseQR());
   std::string cuda_error;
-  if (cuda_dense_qr->Init(&cuda_error)) {
+  if (cuda_dense_qr->Init(options.context, &cuda_error)) {
     return cuda_dense_qr;
   }
   // Initialization failed, destroy the object (done automatically) and return a
diff --git a/internal/ceres/dense_qr.h b/internal/ceres/dense_qr.h
index 8bcccf0..d42cf8c 100644
--- a/internal/ceres/dense_qr.h
+++ b/internal/ceres/dense_qr.h
@@ -45,7 +45,9 @@
 #include "ceres/internal/export.h"
 #include "ceres/linear_solver.h"
 #include "glog/logging.h"
+
 #ifndef CERES_NO_CUDA
+#include "ceres/context_impl.h"
 #include "ceres/cuda_buffer.h"
 #include "cuda_runtime.h"
 #include "cublas_v2.h"
@@ -153,7 +155,6 @@
  public:
   static std::unique_ptr<CUDADenseQR> Create(
       const LinearSolver::Options& options);
-  ~CUDADenseQR() override;
   CUDADenseQR(const CUDADenseQR&) = delete;
   CUDADenseQR& operator=(const CUDADenseQR&) = delete;
   LinearSolverTerminationType Factorize(int num_rows,
@@ -166,11 +167,10 @@
 
  private:
   CUDADenseQR();
-  // Initializes the cuSolverDN context, creates an asynchronous stream, and
-  // associates the stream with cuSolverDN. Returns true iff initialization was
-  // successful, else it returns false and a human-readable error message is
-  // returned.
-  bool Init(std::string* message);
+  // Picks up the cuSolverDN and cuBLAS handles and the CUDA stream from the
+  // context. If the context is unable to initialize CUDA, returns false with a
+  // human-readable message indicating the reason.
+  bool Init(ContextImpl* context, std::string* message);
 
   // Handle to the cuSOLVER context.
   cusolverDnHandle_t cusolver_handle_ = nullptr;
diff --git a/internal/ceres/dense_qr_test.cc b/internal/ceres/dense_qr_test.cc
index d9d307e..f796186 100644
--- a/internal/ceres/dense_qr_test.cc
+++ b/internal/ceres/dense_qr_test.cc
@@ -67,6 +67,8 @@
   using VectorType = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
 
   LinearSolver::Options options;
+  ContextImpl context;
+  options.context = &context;
   options.dense_linear_algebra_library_type = GetParam();
   const double kEpsilon = std::numeric_limits<double>::epsilon() * 1.5e4;
   std::unique_ptr<DenseQR> dense_qr = DenseQR::Create(options);