Skip to content

Commit

Permalink
[NDTensors] Replace axpby with broadcast when β == 0 (#1309)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored Jan 20, 2024
1 parent 2c7f33d commit e1939b6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
6 changes: 2 additions & 4 deletions NDTensors/src/dense/tensoralgebra/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ function _contract_scalar_noperm!(
if iszero(α)
fill!(Rᵈ, 0)
else
# Rᵈ .= α .* T₂ᵈ
LinearAlgebra.axpby!(α, Tᵈ, β, Rᵈ)
Rᵈ .= α .* Tᵈ
end
elseif isone(β)
if iszero(α)
Expand All @@ -81,8 +80,7 @@ function _contract_scalar_noperm!(
end
else
if iszero(α)
# Rᵈ .= β .* Rᵈ
LinearAlgebra.scal!(length(Rᵈ), β, Rᵈ, 1)
Rᵈ .= β .* Rᵈ
else
# Rᵈ .= α .* Tᵈ .+ β .* Rᵈ
LinearAlgebra.axpby!(α, Tᵈ, β, Rᵈ)
Expand Down
26 changes: 26 additions & 0 deletions NDTensors/test/test_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,32 @@ using .NDTensorsTestUtils: devices_list
end

# Only CPU backend testing
@testset "Contract with exotic types" begin
# BigFloat is not supported on GPU
## randn(BigFloat, ...) is not defined in Julia 1.6
a = BigFloat.(randn(Float64, 2, 3))
t = Tensor(a, (1, 2, 3))
m = Tensor(a, (2, 3))
v = Tensor([one(BigFloat)], (1,))

@test m contract(t, (-1, 2, 3), v, (-1,))
tp = similar(t)
NDTensors.contract!(tp, (1, 2, 3), t, (1, 2, 3), v, (1,), false, false)
@test iszero(tp)

fill!(tp, one(BigFloat))
NDTensors.contract!(tp, (1, 2, 3), t, (1, 2, 3), v, (1,), false, true)
for i in tp
@test i == one(BigFloat)
end

rand_factor = BigFloat(randn(Float64))
NDTensors.contract!(tp, (1, 2, 3), t, (1, 2, 3), v, (1,), false, rand_factor)
for i in tp
@test i == rand_factor
end
end

@testset "change backends" begin
a, b, c = [randn(5, 5) for i in 1:3]
backend_auto()
Expand Down

0 comments on commit e1939b6

Please sign in to comment.