Add support for cuDSS 0.8

cuDSS 0.8 replaced cudaDataType_t with its own cudssDataType_t enum in
the matrix creation APIs, and cudssMatrixCreateCsr() gained a separate
parameter for the type of the row offsets. Select the appropriate enum
type at compile time based on CUDSS_VERSION, and pass the additional
row-offset type argument when building against 0.8, so that the code
compiles against both cuDSS 0.7 and 0.8. Also report the new
CUDSS_STATUS_IR_FAILED status added in 0.8.

Fixes https://github.com/ceres-solver/ceres-solver/issues/1203
Refs SW-8464

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Change-Id: Ibf36261bd814f0f50e08caa5cf9391810a51540f
diff --git a/internal/ceres/cuda_sparse_cholesky.cc b/internal/ceres/cuda_sparse_cholesky.cc index 769d90e..689ab08 100644 --- a/internal/ceres/cuda_sparse_cholesky.cc +++ b/internal/ceres/cuda_sparse_cholesky.cc
@@ -48,6 +48,23 @@ namespace ceres::internal { +// cuDSS 0.8 replaced cudaDataType_t with cudssDataType_t in the matrix +// creation APIs. The enumerators of cudssDataType_t mirror the values of +// their cudaDataType_t counterparts, but the two enum types are not +// implicitly convertible, so pick the right type and enumerators at compile +// time to stay compatible with both versions of the API. +#if CUDSS_VERSION >= 800 +using CuDSSDataType = cudssDataType_t; +constexpr CuDSSDataType kCuDSSR32I = CUDSS_R_32I; +constexpr CuDSSDataType kCuDSSR32F = CUDSS_R_32F; +constexpr CuDSSDataType kCuDSSR64F = CUDSS_R_64F; +#else +using CuDSSDataType = cudaDataType_t; +constexpr CuDSSDataType kCuDSSR32I = CUDA_R_32I; +constexpr CuDSSDataType kCuDSSR32F = CUDA_R_32F; +constexpr CuDSSDataType kCuDSSR64F = CUDA_R_64F; +#endif // CUDSS_VERSION >= 800 + inline std::string cuDSSStatusToString(cudssStatus_t status) { switch (status) { case CUDSS_STATUS_SUCCESS: @@ -64,6 +81,10 @@ return "CUDSS_STATUS_EXECUTION_FAILED"; case CUDSS_STATUS_INTERNAL_ERROR: return "CUDSS_STATUS_INTERNAL_ERROR"; +#if CUDSS_VERSION >= 800 + case CUDSS_STATUS_IR_FAILED: + return "CUDSS_STATUS_IR_FAILED"; +#endif // CUDSS_VERSION >= 800 default: return "unknown cuDSS status: " + std::to_string(status); } @@ -120,8 +141,8 @@ void* rows_end, void* cols, void* values, - cudaDataType_t index_type, - cudaDataType_t value_type, + CuDSSDataType index_type, + CuDSSDataType value_type, cudssMatrixType_t matrix_type, cudssMatrixViewType_t matrix_storage_type, cudssIndexBase_t index_base) { @@ -135,6 +156,12 @@ rows_end, cols, values, +#if CUDSS_VERSION >= 800 + // cuDSS 0.8 takes the type of the row offsets + // separately from the type of the column + // indices; Ceres uses the same type for both. 
+ index_type, +#endif // CUDSS_VERSION >= 800 index_type, value_type, matrix_type, @@ -149,7 +176,7 @@ int64_t num_cols, int64_t leading_dimension_size, void* values, - cudaDataType_t value_type, + CuDSSDataType value_type, cudssLayout_t layout) { CUDSS_STATUS_OK_OR_RETURN_CUDSS_STATUS(Free()); @@ -187,8 +214,8 @@ public: static_assert(std::is_same_v<Scalar, float> || std::is_same_v<Scalar, double>, "Scalar type is unsupported by cuDSS"); - static constexpr cudaDataType_t kCuDSSScalar = - std::is_same_v<Scalar, float> ? CUDA_R_32F : CUDA_R_64F; + static constexpr CuDSSDataType kCuDSSScalar = + std::is_same_v<Scalar, float> ? kCuDSSR32F : kCuDSSR64F; CudaSparseCholeskyImpl(ContextImpl* context) : context_(context), @@ -377,7 +404,7 @@ nullptr, lhs_cols_d_.data(), lhs_values_d_.data(), - CUDA_R_32I, + kCuDSSR32I, kCuDSSScalar, CUDSS_MTYPE_SPD, CUDSS_MVIEW_LOWER,