| // Ceres Solver - A fast non-linear least squares minimizer | 
 | // Copyright 2024 Google Inc. All rights reserved. | 
 | // http://ceres-solver.org/ | 
 | // | 
 | // Redistribution and use in source and binary forms, with or without | 
 | // modification, are permitted provided that the following conditions are met: | 
 | // | 
 | // * Redistributions of source code must retain the above copyright notice, | 
 | //   this list of conditions and the following disclaimer. | 
 | // * Redistributions in binary form must reproduce the above copyright notice, | 
 | //   this list of conditions and the following disclaimer in the documentation | 
 | //   and/or other materials provided with the distribution. | 
 | // * Neither the name of Google Inc. nor the names of its contributors may be | 
 | //   used to endorse or promote products derived from this software without | 
 | //   specific prior written permission. | 
 | // | 
 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
 | // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | 
 | // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | 
 | // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | 
 | // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | 
 | // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | 
 | // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | 
 | // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | 
 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | 
 | // POSSIBILITY OF SUCH DAMAGE. | 
 | // | 
 | // Author: markshachkov@gmail.com (Mark Shachkov) | 
 |  | 
 | #include "ceres/cuda_sparse_cholesky.h" | 
 |  | 
 | #ifndef CERES_NO_CUDSS | 
 |  | 
 | #include <cstddef> | 
 | #include <iostream> | 
 | #include <memory> | 
 | #include <string> | 
 | #include <type_traits> | 
 |  | 
 | #include "Eigen/Core" | 
 | #include "absl/log/check.h" | 
 | #include "absl/log/log.h" | 
 | #include "ceres/compressed_row_sparse_matrix.h" | 
 | #include "ceres/cuda_buffer.h" | 
 | #include "ceres/linear_solver.h" | 
 | #include "cudss.h" | 
 |  | 
 | namespace ceres::internal { | 
 |  | 
 | inline std::string cuDSSStatusToString(cudssStatus_t status) { | 
 |   switch (status) { | 
 |     case CUDSS_STATUS_SUCCESS: | 
 |       return "CUDSS_STATUS_SUCCESS"; | 
 |     case CUDSS_STATUS_NOT_INITIALIZED: | 
 |       return "CUDSS_STATUS_NOT_INITIALIZED"; | 
 |     case CUDSS_STATUS_ALLOC_FAILED: | 
 |       return "CUDSS_STATUS_ALLOC_FAILED"; | 
 |     case CUDSS_STATUS_INVALID_VALUE: | 
 |       return "CUDSS_STATUS_INVALID_VALUE"; | 
 |     case CUDSS_STATUS_NOT_SUPPORTED: | 
 |       return "CUDSS_STATUS_NOT_SUPPORTED"; | 
 |     case CUDSS_STATUS_EXECUTION_FAILED: | 
 |       return "CUDSS_STATUS_EXECUTION_FAILED"; | 
 |     case CUDSS_STATUS_INTERNAL_ERROR: | 
 |       return "CUDSS_STATUS_INTERNAL_ERROR"; | 
 |     default: | 
 |       return "unknown cuDSS status: " + std::to_string(status); | 
 |   } | 
 | } | 
 |  | 
 | #define CUDSS_STATUS_CHECK(IN)                                     \ | 
 |   if (cudssStatus_t status = IN; status != CUDSS_STATUS_SUCCESS) { \ | 
 |     CHECK(false) << "Got error: " << cuDSSStatusToString(status);  \ | 
 |   } | 
 |  | 
 | #define CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR(IN, additional_message)    \ | 
 |   if (cudssStatus_t status = IN; status != CUDSS_STATUS_SUCCESS) {       \ | 
 |     *message = std::string(additional_message) +                         \ | 
 |                " Got error: " + cuDSSStatusToString(status);             \ | 
 |     return factorize_result_ = LinearSolverTerminationType::FATAL_ERROR; \ | 
 |   } | 
 |  | 
 | #define CUDSS_STATUS_OK_OR_RETURN_CUDSS_STATUS(IN)                 \ | 
 |   if (cudssStatus_t status = IN; status != CUDSS_STATUS_SUCCESS) { \ | 
 |     return status;                                                 \ | 
 |   } | 
 |  | 
 | class CERES_NO_EXPORT CuDSSMatrixBase { | 
 |  public: | 
 |   CuDSSMatrixBase() = default; | 
 |   CuDSSMatrixBase(const CuDSSMatrixBase&) = delete; | 
 |   CuDSSMatrixBase(CuDSSMatrixBase&&) = delete; | 
 |   CuDSSMatrixBase& operator=(const CuDSSMatrixBase&) = delete; | 
 |   CuDSSMatrixBase& operator=(CuDSSMatrixBase&&) = delete; | 
 |   ~CuDSSMatrixBase() { CUDSS_STATUS_CHECK(Free()); } | 
 |  | 
 |   cudssStatus_t Free() noexcept { | 
 |     if (matrix_) { | 
 |       const auto status = cudssMatrixDestroy(matrix_); | 
 |       matrix_ = nullptr; | 
 |       return status; | 
 |     } | 
 |  | 
 |     return CUDSS_STATUS_SUCCESS; | 
 |   } | 
 |  | 
 |   cudssMatrix_t Get() const noexcept { return matrix_; } | 
 |  | 
 |  protected: | 
 |   cudssMatrix_t matrix_{nullptr}; | 
 | }; | 
 |  | 
 | class CERES_NO_EXPORT CuDSSMatrixCSR : public CuDSSMatrixBase { | 
 |  public: | 
 |   cudssStatus_t Reset(int64_t num_rows, | 
 |                       int64_t num_cols, | 
 |                       int64_t num_nonzeros, | 
 |                       void* rows_start, | 
 |                       void* rows_end, | 
 |                       void* cols, | 
 |                       void* values, | 
 |                       cudaDataType_t index_type, | 
 |                       cudaDataType_t value_type, | 
 |                       cudssMatrixType_t matrix_type, | 
 |                       cudssMatrixViewType_t matrix_storage_type, | 
 |                       cudssIndexBase_t index_base) { | 
 |     CUDSS_STATUS_OK_OR_RETURN_CUDSS_STATUS(Free()); | 
 |  | 
 |     return cudssMatrixCreateCsr(&matrix_, | 
 |                                 num_rows, | 
 |                                 num_cols, | 
 |                                 num_nonzeros, | 
 |                                 rows_start, | 
 |                                 rows_end, | 
 |                                 cols, | 
 |                                 values, | 
 |                                 index_type, | 
 |                                 value_type, | 
 |                                 matrix_type, | 
 |                                 matrix_storage_type, | 
 |                                 index_base); | 
 |   } | 
 | }; | 
 |  | 
 | class CERES_NO_EXPORT CuDSSMatrixDense : public CuDSSMatrixBase { | 
 |  public: | 
 |   cudssStatus_t Reset(int64_t num_rows, | 
 |                       int64_t num_cols, | 
 |                       int64_t leading_dimension_size, | 
 |                       void* values, | 
 |                       cudaDataType_t value_type, | 
 |                       cudssLayout_t layout) { | 
 |     CUDSS_STATUS_OK_OR_RETURN_CUDSS_STATUS(Free()); | 
 |  | 
 |     return cudssMatrixCreateDn(&matrix_, | 
 |                                num_rows, | 
 |                                num_cols, | 
 |                                leading_dimension_size, | 
 |                                values, | 
 |                                value_type, | 
 |                                layout); | 
 |   } | 
 | }; | 
 |  | 
 | struct CudssContext { | 
 |   CudssContext(cudssHandle_t cudss_handle) : cudss_handle_(cudss_handle) { | 
 |     CUDSS_STATUS_CHECK(cudssConfigCreate(&solver_config_)); | 
 |     CUDSS_STATUS_CHECK(cudssDataCreate(cudss_handle_, &solver_data_)); | 
 |   } | 
 |   CudssContext(const CudssContext&) = delete; | 
 |   CudssContext(CudssContext&&) = delete; | 
 |   CudssContext& operator=(const CudssContext&) = delete; | 
 |   CudssContext& operator=(CudssContext&&) = delete; | 
 |   ~CudssContext() { | 
 |     CUDSS_STATUS_CHECK(cudssDataDestroy(cudss_handle_, solver_data_)); | 
 |     CUDSS_STATUS_CHECK(cudssConfigDestroy(solver_config_)); | 
 |   } | 
 |  | 
 |   cudssHandle_t cudss_handle_{nullptr}; | 
 |   cudssConfig_t solver_config_{nullptr}; | 
 |   cudssData_t solver_data_{nullptr}; | 
 | }; | 
 |  | 
 | template <typename Scalar> | 
 | class CERES_NO_EXPORT CudaSparseCholeskyImpl final : public SparseCholesky { | 
 |  public: | 
 |   static_assert(std::is_same_v<Scalar, float> || std::is_same_v<Scalar, double>, | 
 |                 "Scalar type is unsupported by cuDSS"); | 
 |   static constexpr cudaDataType_t kCuDSSScalar = | 
 |       std::is_same_v<Scalar, float> ? CUDA_R_32F : CUDA_R_64F; | 
 |  | 
 |   CudaSparseCholeskyImpl(ContextImpl* context) | 
 |       : context_(context), | 
 |         lhs_cols_d_(context_), | 
 |         lhs_rows_d_(context_), | 
 |         lhs_values_d_(context_), | 
 |         rhs_d_(context_), | 
 |         x_d_(context_) {} | 
 |   CudaSparseCholeskyImpl(const CudaSparseCholeskyImpl&) = delete; | 
 |   CudaSparseCholeskyImpl(CudaSparseCholeskyImpl&&) = delete; | 
 |   CudaSparseCholeskyImpl& operator=(const CudaSparseCholeskyImpl&) = delete; | 
 |   CudaSparseCholeskyImpl& operator=(CudaSparseCholeskyImpl&&) = delete; | 
 |   ~CudaSparseCholeskyImpl() = default; | 
 |  | 
 |   CompressedRowSparseMatrix::StorageType StorageType() const { | 
 |     return CompressedRowSparseMatrix::StorageType::LOWER_TRIANGULAR; | 
 |   } | 
 |  | 
 |   LinearSolverTerminationType Factorize(CompressedRowSparseMatrix* lhs, | 
 |                                         std::string* message) { | 
 |     if (lhs->num_rows() != lhs->num_cols()) { | 
 |       *message = "lhs matrix must be square"; | 
 |       return factorize_result_ = LinearSolverTerminationType::FATAL_ERROR; | 
 |     } | 
 |     if (lhs->storage_type() != StorageType()) { | 
 |       *message = "lhs matrix must be lower triangular"; | 
 |       return factorize_result_ = LinearSolverTerminationType::FATAL_ERROR; | 
 |     } | 
 |  | 
 |     // If, after previous attempt to factorize, cudssDataGet(CUDSS_DATA_INFO) | 
 |     // returned a numerical error, such error will be preserved by cuDSS 0.3.0 | 
 |     // and returned by cudssDataGet(CUDSS_DATA_INFO) even after correctly | 
 |     // factorizeable matrix is provided. Such behaviour forces us to reset a | 
 |     // cudssData_t object (managed by CudssContext) and to loose a result of | 
 |     // anylyze stage, thus we have to perform analyze one more time. | 
 |     // TODO: do not re-perform analyze in case of failed factorization numerics | 
 |     if (analyze_result_ != LinearSolverTerminationType::SUCCESS || | 
 |         factorize_result_ != LinearSolverTerminationType::SUCCESS) { | 
 |       analyze_result_ = Analyze(lhs, message); | 
 |       if (analyze_result_ != LinearSolverTerminationType::SUCCESS) { | 
 |         return analyze_result_; | 
 |       } | 
 |     } | 
 |     CHECK_NE(cudss_context_.get(), nullptr); | 
 |  | 
 |     ConvertAndCopyToDevice(lhs->values(), lhs_values_h_.data(), lhs_values_d_); | 
 |  | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR( | 
 |         cudssExecute(context_->cudss_handle_, | 
 |                      CUDSS_PHASE_FACTORIZATION, | 
 |                      cudss_context_->solver_config_, | 
 |                      cudss_context_->solver_data_, | 
 |                      cudss_lhs_.Get(), | 
 |                      cudss_x_.Get(), | 
 |                      cudss_rhs_.Get()), | 
 |         "cudssExecute with CUDSS_PHASE_FACTORIZATION failed"); | 
 |  | 
 |     int cudss_data_info; | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR( | 
 |         GetCudssDataInfo(cudss_data_info), | 
 |         "cudssDataGet with CUDSS_DATA_INFO failed"); | 
 |     const auto factorization_status = | 
 |         static_cast<cudssStatus_t>(cudss_data_info); | 
 |  | 
 |     if (factorization_status == CUDSS_STATUS_SUCCESS) { | 
 |       return factorize_result_ = LinearSolverTerminationType::SUCCESS; | 
 |     } | 
 |  | 
 |     if (cudss_data_info > 0) { | 
 |       return factorize_result_ = LinearSolverTerminationType::FAILURE; | 
 |     } | 
 |  | 
 |     return factorize_result_ = LinearSolverTerminationType::FATAL_ERROR; | 
 |   } | 
 |  | 
 |   LinearSolverTerminationType Solve(const double* rhs, | 
 |                                     double* solution, | 
 |                                     std::string* message) { | 
 |     CHECK_NE(cudss_context_.get(), nullptr); | 
 |  | 
 |     if (factorize_result_ != LinearSolverTerminationType::SUCCESS) { | 
 |       *message = "Factorize did not complete successfully previously."; | 
 |       return factorize_result_; | 
 |     } | 
 |  | 
 |     ConvertAndCopyToDevice(rhs, rhs_h_.data(), rhs_d_); | 
 |  | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR( | 
 |         cudssExecute(context_->cudss_handle_, | 
 |                      CUDSS_PHASE_SOLVE, | 
 |                      cudss_context_->solver_config_, | 
 |                      cudss_context_->solver_data_, | 
 |                      cudss_lhs_.Get(), | 
 |                      cudss_x_.Get(), | 
 |                      cudss_rhs_.Get()), | 
 |         "cudssExecute with CUDSS_PHASE_SOLVE failed"); | 
 |  | 
 |     ConvertAndCopyToHost(x_d_, x_h_.data(), solution); | 
 |  | 
 |     int cudss_data_info; | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR( | 
 |         GetCudssDataInfo(cudss_data_info), | 
 |         "cudssDataGet with CUDSS_DATA_INFO failed"); | 
 |     const auto solve_status = static_cast<cudssStatus_t>(cudss_data_info); | 
 |  | 
 |     if (solve_status != CUDSS_STATUS_SUCCESS) { | 
 |       return LinearSolverTerminationType::FAILURE; | 
 |     } | 
 |  | 
 |     return LinearSolverTerminationType::SUCCESS; | 
 |   } | 
 |  | 
 |  private: | 
 |   cudssStatus_t GetCudssDataInfo(int& cudss_data_info) { | 
 |     CHECK_NE(cudss_context_.get(), nullptr); | 
 |     std::size_t size_written = 0; | 
 |     CUDSS_STATUS_OK_OR_RETURN_CUDSS_STATUS( | 
 |         cudssDataGet(context_->cudss_handle_, | 
 |                      cudss_context_->solver_data_, | 
 |                      CUDSS_DATA_INFO, | 
 |                      &cudss_data_info, | 
 |                      sizeof(cudss_data_info), | 
 |                      &size_written)); | 
 |     // TODO: enable following check after cudssDataGet will be fixed | 
 |     // CHECK_EQ(size_written, sizeof(cudss_data_info)); | 
 |  | 
 |     return CUDSS_STATUS_SUCCESS; | 
 |   } | 
 |  | 
 |   LinearSolverTerminationType Analyze(const CompressedRowSparseMatrix* lhs, | 
 |                                       std::string* message) { | 
 |     if (auto status = SetupCudssMatrices(lhs, message); | 
 |         status != LinearSolverTerminationType::SUCCESS) { | 
 |       return status; | 
 |     } | 
 |  | 
 |     lhs_rows_d_.CopyFromCpu(lhs->rows(), lhs->num_rows() + 1); | 
 |     lhs_cols_d_.CopyFromCpu(lhs->cols(), lhs->num_nonzeros()); | 
 |  | 
 |     // Analyze and factorization results are stored in cudssData_t (managed by | 
 |     // CudssContext). Given that cuDSS 0.3.0 does not reset it's error state in | 
 |     // case of failed numerics at factorization stage, we have to reset | 
 |     // cudssData_t and to recompute an analyze stage while trying to factorize a | 
 |     // rescaled matrix with the same structure. | 
 |     // TODO: move creation of CudssContext to ctor of CudaSparseCholeskyImpl | 
 |     cudss_context_ = std::make_unique<CudssContext>(context_->cudss_handle_); | 
 |  | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR( | 
 |         cudssExecute(context_->cudss_handle_, | 
 |                      CUDSS_PHASE_ANALYSIS, | 
 |                      cudss_context_->solver_config_, | 
 |                      cudss_context_->solver_data_, | 
 |                      cudss_lhs_.Get(), | 
 |                      cudss_x_.Get(), | 
 |                      cudss_rhs_.Get()), | 
 |         "cudssExecute with CUDSS_PHASE_ANALYSIS failed"); | 
 |  | 
 |     return LinearSolverTerminationType::SUCCESS; | 
 |   } | 
 |  | 
 |   // Resize buffers and setup cuDSS structs that describe the type and storage | 
 |   // configuration of linear system operands. | 
 |   LinearSolverTerminationType SetupCudssMatrices( | 
 |       const CompressedRowSparseMatrix* lhs, std::string* message) { | 
 |     const auto num_rows = lhs->num_rows(); | 
 |     const auto num_nonzeros = lhs->num_nonzeros(); | 
 |  | 
 |     if constexpr (std::is_same_v<Scalar, float>) { | 
 |       lhs_values_h_.Reserve(num_nonzeros); | 
 |       rhs_h_.Reserve(num_rows); | 
 |       x_h_.Reserve(num_rows); | 
 |     } | 
 |  | 
 |     lhs_rows_d_.Reserve(num_rows + 1); | 
 |     lhs_cols_d_.Reserve(num_nonzeros); | 
 |     lhs_values_d_.Reserve(num_nonzeros); | 
 |     rhs_d_.Reserve(num_rows); | 
 |     x_d_.Reserve(num_rows); | 
 |  | 
 |     static constexpr auto kFailedToCreateCuDSSMatrix = | 
 |         "cudssMatrixCreate() call failed"; | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR(cudss_lhs_.Reset(num_rows, | 
 |                                                            num_rows, | 
 |                                                            num_nonzeros, | 
 |                                                            lhs_rows_d_.data(), | 
 |                                                            nullptr, | 
 |                                                            lhs_cols_d_.data(), | 
 |                                                            lhs_values_d_.data(), | 
 |                                                            CUDA_R_32I, | 
 |                                                            kCuDSSScalar, | 
 |                                                            CUDSS_MTYPE_SPD, | 
 |                                                            CUDSS_MVIEW_LOWER, | 
 |                                                            CUDSS_BASE_ZERO), | 
 |                                           kFailedToCreateCuDSSMatrix); | 
 |  | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR( | 
 |         cudss_rhs_.Reset(num_rows, | 
 |                          1, | 
 |                          num_rows, | 
 |                          rhs_d_.data(), | 
 |                          kCuDSSScalar, | 
 |                          CUDSS_LAYOUT_COL_MAJOR), | 
 |         kFailedToCreateCuDSSMatrix); | 
 |  | 
 |     CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR( | 
 |         cudss_x_.Reset(num_rows, | 
 |                        1, | 
 |                        num_rows, | 
 |                        x_d_.data(), | 
 |                        kCuDSSScalar, | 
 |                        CUDSS_LAYOUT_COL_MAJOR), | 
 |         kFailedToCreateCuDSSMatrix); | 
 |  | 
 |     return LinearSolverTerminationType::SUCCESS; | 
 |   } | 
 |  | 
 |   template <typename S, typename D> | 
 |   void Convert(const S* source, D* destination, size_t size) { | 
 |     Eigen::Map<Eigen::Matrix<D, Eigen::Dynamic, 1>>(destination, size) = | 
 |         Eigen::Map<const Eigen::Matrix<S, Eigen::Dynamic, 1>>(source, size) | 
 |             .template cast<D>(); | 
 |   } | 
 |  | 
 |   void ConvertAndCopyToDevice(const double* source, | 
 |                               Scalar* intermediate, | 
 |                               CudaBuffer<Scalar>& destination) { | 
 |     const auto size = destination.size(); | 
 |     if constexpr (std::is_same_v<Scalar, double>) { | 
 |       destination.CopyFromCpu(source, size); | 
 |     } else { | 
 |       Convert(source, intermediate, size); | 
 |       destination.CopyFromCpu(intermediate, size); | 
 |     } | 
 |   } | 
 |  | 
 |   void ConvertAndCopyToHost(const CudaBuffer<Scalar>& source, | 
 |                             Scalar* intermediate, | 
 |                             double* destination) { | 
 |     const auto size = source.size(); | 
 |     if constexpr (std::is_same_v<Scalar, double>) { | 
 |       source.CopyToCpu(destination, source.size()); | 
 |     } else { | 
 |       source.CopyToCpu(intermediate, source.size()); | 
 |       Convert(intermediate, destination, size); | 
 |     } | 
 |   } | 
 |  | 
 |   ContextImpl* context_{nullptr}; | 
 |   std::unique_ptr<CudssContext> cudss_context_; | 
 |   CuDSSMatrixCSR cudss_lhs_; | 
 |   CuDSSMatrixDense cudss_rhs_; | 
 |   CuDSSMatrixDense cudss_x_; | 
 |  | 
 |   CudaPinnedHostBuffer<Scalar> lhs_values_h_; | 
 |   CudaPinnedHostBuffer<Scalar> rhs_h_; | 
 |   CudaPinnedHostBuffer<Scalar> x_h_; | 
 |   CudaBuffer<int> lhs_rows_d_; | 
 |   CudaBuffer<int> lhs_cols_d_; | 
 |   CudaBuffer<Scalar> lhs_values_d_; | 
 |   CudaBuffer<Scalar> rhs_d_; | 
 |   CudaBuffer<Scalar> x_d_; | 
 |  | 
 |   LinearSolverTerminationType analyze_result_ = | 
 |       LinearSolverTerminationType::FATAL_ERROR; | 
 |   LinearSolverTerminationType factorize_result_ = | 
 |       LinearSolverTerminationType::FATAL_ERROR; | 
 | }; | 
 |  | 
 | template <typename Scalar> | 
 | std::unique_ptr<SparseCholesky> CudaSparseCholesky<Scalar>::Create( | 
 |     ContextImpl* context, const OrderingType ordering_type) { | 
 |   if (ordering_type == OrderingType::NESDIS) { | 
 |     LOG(FATAL) | 
 |         << "Congratulations you have found a bug in Ceres Solver. Please " | 
 |            "report it to the Ceres Solver developers."; | 
 |     return nullptr; | 
 |   } | 
 |  | 
 |   if (context == nullptr || !context->IsCudaInitialized()) { | 
 |     LOG(FATAL) << "CudaSparseCholesky requires CUDA context to be initialized"; | 
 |     return nullptr; | 
 |   } | 
 |  | 
 |   return std::make_unique<CudaSparseCholeskyImpl<Scalar>>(context); | 
 | } | 
 |  | 
 | template class CudaSparseCholesky<float>; | 
 | template class CudaSparseCholesky<double>; | 
 |  | 
 | }  // namespace ceres::internal | 
 |  | 
// The status-handling macros are local helpers of this file. Undefine them so
// their names do not stay defined for the rest of the translation unit.
#undef CUDSS_STATUS_CHECK
#undef CUDSS_STATUS_OK_OR_RETURN_FATAL_ERROR
#undef CUDSS_STATUS_OK_OR_RETURN_CUDSS_STATUS
 |  | 
 | #endif  // CERES_NO_CUDSS |