Skip to content

Commit

Permalink
Use CUDA.jl 5.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Nov 20, 2023
1 parent d406df3 commit 1787498
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ uuid = "a8cc9031-bad2-4722-94f5-40deabb4245c"
version = "0.2.3"

[compat]
CUDA = "4, 5.1"
CUDA = "4, 5.1.1"
KLU = "0.3, 0.4"
julia = "1.6"

Expand Down
47 changes: 4 additions & 43 deletions src/backsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,6 @@ function _cu_matrix_description(A::CUSPARSE.CuSparseMatrixCSR, uplo, diag, index
return desc
end

# TODO: Add these constructors in CUDA.jl
#----------------------------------------------------------------------------------------------------------------#
mutable struct CuDenseVectorDescriptor2
handle::CUSPARSE.cusparseDnVecDescr_t

function CuDenseVectorDescriptor2(T::DataType, n::Int)
desc_ref = Ref{CUSPARSE.cusparseDnVecDescr_t}()
CUSPARSE.cusparseCreateDnVec(desc_ref, n, CUDA.CU_NULL, T)
obj = new(desc_ref[])
finalizer(CUSPARSE.cusparseDestroyDnVec, obj)
obj
end
end

Base.unsafe_convert(::Type{CUSPARSE.cusparseDnVecDescr_t}, desc::CuDenseVectorDescriptor2) = desc.handle

mutable struct CuDenseMatrixDescriptor2
handle::CUSPARSE.cusparseDnMatDescr_t

function CuDenseMatrixDescriptor2(T::DataType, m::Int, n::Int)
desc_ref = Ref{CUSPARSE.cusparseDnMatDescr_t}()
CUSPARSE.cusparseCreateDnMat(desc_ref, m, n, m, CUDA.CU_NULL, T, 'C')
obj = new(desc_ref[])
finalizer(CUSPARSE.cusparseDestroyDnMat, obj)
obj
end
end

Base.unsafe_convert(::Type{CUSPARSE.cusparseDnMatDescr_t}, desc::CuDenseMatrixDescriptor2) = desc.handle

struct CuSparseSV <: AbstractBacksolve
n::Int
algo::CUSPARSE.cusparseSpSVAlg_t
Expand All @@ -50,15 +20,6 @@ struct CuSparseSV <: AbstractBacksolve
bufferU::CuVector{UInt8}
end

function cusparseSpSV_updateMatrix(handle, spsvDescr, newValues, updatePart)
CUDA.initialize_context()
@ccall CUSPARSE.libcusparse.cusparseSpSV_updateMatrix(handle::CUSPARSE.cusparseHandle_t,
spsvDescr::CUSPARSE.cusparseSpSVDescr_t,
newValues::CUDA.CuPtr{Cvoid},
updatePart::CUSPARSE.cusparseSpSVUpdate_t)::CUSPARSE.cusparseStatus_t
end
#----------------------------------------------------------------------------------------------------------------#

function CuSparseSV(
A::CUSPARSE.CuSparseMatrixCSR{T}, transa::CUSPARSE.SparseChar;
algo=CUSPARSE.CUSPARSE_SPSV_ALG_DEFAULT,
Expand All @@ -79,7 +40,7 @@ function CuSparseSV(
alpha = one(T)

# Dummy descriptor
descX = CuDenseVectorDescriptor2(T, n)
descX = CUSPARSE.CuDenseVectorDescriptor(T, n)

# Descriptor for lower-triangular SpSV operation
spsv_L = CUSPARSE.CuSparseSpSVDescriptor()
Expand Down Expand Up @@ -118,8 +79,8 @@ function backsolve!(s::CuSparseSV, A::CUSPARSE.CuSparseMatrixCSR{T}, X::CuVector
alpha = one(T)

descX = CUSPARSE.CuDenseVectorDescriptor(X)
cusparseSpSV_updateMatrix(CUSPARSE.handle(), s.infoL, A.nzVal, CUSPARSE.CUSPARSE_SPSV_UPDATE_GENERAL)
cusparseSpSV_updateMatrix(CUSPARSE.handle(), s.infoU, A.nzVal, CUSPARSE.CUSPARSE_SPSV_UPDATE_GENERAL)
CUSPARSE.cusparseSpSV_updateMatrix(CUSPARSE.handle(), s.infoL, A.nzVal, CUSPARSE.CUSPARSE_SPSV_UPDATE_GENERAL)
CUSPARSE.cusparseSpSV_updateMatrix(CUSPARSE.handle(), s.infoU, A.nzVal, CUSPARSE.CUSPARSE_SPSV_UPDATE_GENERAL)

if s.transa == 'N'
CUSPARSE.cusparseSpSV_solve(
Expand Down Expand Up @@ -175,7 +136,7 @@ function CuSparseSM(
transx = 'N'
mX, nX = size(X)
@assert m == mX
descX = CuDenseMatrixDescriptor2(T, mX, nX)
descX = CUSPARSE.CuDenseMatrixDescriptor(T, mX, nX)

# Descriptor for lower-triangular SpSM operation
spsm_L = CUSPARSE.CuSparseSpSMDescriptor()
Expand Down
4 changes: 2 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function LinearAlgebra.lu!(rf::RFLU, J::CuSparseMatrixCSR)
alpha = one(T)
# Update rf.dsm
dsm = rf.dsm
descX = CuDenseMatrixDescriptor2(T, dsm.n, dsm.nrhs)
descX = CUSPARSE.CuDenseMatrixDescriptor(T, dsm.n, dsm.nrhs)
CUSPARSE.cusparseSpSM_analysis(
CUSPARSE.handle(), dsm.transa, 'N', Ref{T}(alpha), dsm.descL, descX, descX, T, dsm.algo, dsm.infoL, dsm.bufferL,
)
Expand All @@ -100,7 +100,7 @@ function LinearAlgebra.lu!(rf::RFLU, J::CuSparseMatrixCSR)
)
# Update rf.tsm
tsm = rf.tsm
descX = CuDenseMatrixDescriptor2(T, tsm.n, tsm.nrhs)
descX = CUSPARSE.CuDenseMatrixDescriptor(T, tsm.n, tsm.nrhs)
CUSPARSE.cusparseSpSM_analysis(
CUSPARSE.handle(), tsm.transa, 'N', Ref{T}(alpha), tsm.descL, descX, descX, T, tsm.algo, tsm.infoL, tsm.bufferL,
)
Expand Down

0 comments on commit 1787498

Please sign in to comment.