Skip to content

Commit

Permalink
Fix issues with GradedAxes
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 27, 2024
1 parent 849051b commit 87b2ee1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,18 @@ struct BlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
block::Tuple{Vararg{Block{1,Int},N}}
end
function Base.axes(a::BlockView)
# TODO: Try to avoid conversion to `Base.OneTo{Int}`, or just convert
# the element type to `Int` with `Int.(...)`.
# When the axes of `a.array` are `GradedOneTo`, the block is `LabelledUnitRange`,
# which has element type `LabelledInteger`. That causes conversion problems
# in some generic Base Julia code, for example when printing `BlockView`.
return ntuple(ndims(a)) do dim
return Base.OneTo{Int}(only(axes(axes(a.array, dim)[a.block[dim]])))
end
end
function Base.size(a::BlockView)
return ntuple(dim -> length(axes(a.array, dim)[a.block[dim]]), ndims(a))
return length.(axes(a))
end
function Base.getindex(a::BlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
return blocks(a.array)[Int.(a.block)...][index...]
Expand Down
15 changes: 15 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/unitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ end
using NDTensors.LabelledNumbers: LabelledNumbers, label
LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a)))

using NDTensors.LabelledNumbers: LabelledUnitRange
# The Base version of `length(::AbstractUnitRange)` drops the label.
function Base.length(a::UnitRangeDual{<:Any,<:LabelledUnitRange})
return dual(length(nondual(a)))
end
function Base.iterate(a::UnitRangeDual, i)
i == last(a) && return nothing
return dual.(iterate(nondual(a), i))
end
# TODO: Is this a good definition?
Base.unitrange(a::UnitRangeDual{<:Any,<:AbstractUnitRange}) = a

using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel
dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))

using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock
BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a))
BlockArrays.blockfirsts(a::UnitRangeDual) = label_dual.(blockfirsts(nondual(a)))
Expand Down
3 changes: 3 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ function Base.OrdinalRange{T,T}(a::LabelledUnitRange) where {T<:Integer}
return OrdinalRange{T,T}(unlabel(a))
end

# TODO: Is this a good definition?
Base.unitrange(a::LabelledUnitRange) = a

for f in [:first, :getindex, :last, :length, :step]
@eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a))
end
Expand Down

0 comments on commit 87b2ee1

Please sign in to comment.