Skip to content

Commit

Permalink
Fix some more SVD issues on Metal
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 23, 2023
1 parent d60b852 commit 2b8deb4
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions NDTensors/src/linearalgebra/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,21 @@ function svd(T::DenseTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT}

P = MS .^ 2
if truncate
P_cpu = NDTensors.cpu(P)
truncerr, _ = truncate!(
P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
P_cpu; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
)
P = adapt(typeof(P), P_cpu)
else
truncerr = 0.0
end
spec = Spectrum(P, truncerr)
dS = length(P)
if dS < length(MS)
MU = MU[:, 1:dS]
resize!(MS, dS)
# Fails on some GPU backends like Metal.
# resize!(MS, dS)
MS = MS[1:dS]
MV = MV[:, 1:dS]
end

Expand Down Expand Up @@ -236,11 +240,11 @@ function eigen(
VM = VM[:, p]

if truncate
cpu_dm = NDTensors.cpu(DM)
DM_cpu = NDTensors.cpu(DM)
truncerr, _ = truncate!(
cpu_dm; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
DM_cpu; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
)
DM = adapt(typeof(DM), cpu_dm)
DM = adapt(typeof(DM), DM_cpu)
dD = length(DM)
if dD < size(VM, 2)
VM = VM[:, 1:dD]
Expand Down Expand Up @@ -354,10 +358,13 @@ function eigen(
#DM = DM[p]
#VM = VM[:,p]


if truncate
DM_cpu = NDTensors.cpu(DM)
truncerr, _ = truncate!(
DM; maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
DM_cpu; maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
)
DM = adapt(typeof(DM), DM_cpu)
dD = length(DM)
if dD < size(VM, 2)
VM = VM[:, 1:dD]
Expand Down

0 comments on commit 2b8deb4

Please sign in to comment.