diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index 6f6ee164..653303aa 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -585,16 +585,25 @@ end # For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged? function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...) + return _orthogonalize_edges(tn, [edge]; kwargs...) +end + +# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged? +function _orthogonalize_edges( + tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs... +) # tn = factorize(tn, edge; kwargs...) # # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))` # new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge))) # return contract(tn, new_vertex => dst(edge)) tn = copy(tn) - left_inds = uniqueinds(tn, edge) - ltags = tags(tn, edge) - X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...) - tn[src(edge)] = X - tn[dst(edge)] *= Y + for edge in edges + left_inds = uniqueinds(tn, edge) + ltags = tags(tn, edge) + X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...) + tn[src(edge)] = X + tn[dst(edge)] *= Y + end return tn end @@ -602,19 +611,35 @@ function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge return _orthogonalize_edge(tn, edge; kwargs...) end +function ITensorMPS.orthogonalize( + tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs... +) + return _orthogonalize_edges(tn, edges; kwargs...) +end + function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...) return orthogonalize(tn, edgetype(tn)(edge); kwargs...) end -# Orthogonalize an ITensorNetwork towards a source vertex, treating +function ITensorMPS.orthogonalize( + tn::AbstractITensorNetwork, edges::Vector{Pair}; kwargs... +) + return orthogonalize(tn, edgetype(tn).(edges); kwargs...) +end + +# Orthogonalize an ITensorNetwork towards a region, treating # the network as a tree spanned by a spanning tree. # TODO: Rename `tree_orthogonalize`. -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, source_vertex) - spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex) - for e in spanning_tree_edges - ψ = orthogonalize(ψ, e) - end - return ψ +function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, region::Vector) + spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, first(region)), first(region)) + spanning_tree_edges = filter( + e -> !(src(e) ∈ region && dst(e) ∈ region), spanning_tree_edges + ) + return orthogonalize(ψ, spanning_tree_edges) +end + +function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, region) + return orthogonalize(ψ, [region]) end # TODO: decide whether to use graph mutating methods when resulting graph is unchanged? diff --git a/src/solvers/alternating_update/region_update.jl b/src/solvers/alternating_update/region_update.jl index 97241c20..6ad096b4 100644 --- a/src/solvers/alternating_update/region_update.jl +++ b/src/solvers/alternating_update/region_update.jl @@ -1,35 +1,3 @@ -#ToDo: generalize beyond 2-site -#ToDo: remove concept of orthogonality center for generality -function current_ortho(sweep_plan, which_region_update) - regions = first.(sweep_plan) - region = regions[which_region_update] - current_verts = support(region) - if !isa(region, AbstractEdge) && length(region) == 1 - return only(current_verts) - end - # look forward - other_regions = filter( - x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end]) - ) - # find the first region that has overlapping support with current region - ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - if isnothing(ind) - # look backward - other_regions = reverse( - filter( - x -> !(issetequal(x, current_verts)), support.(regions[1:(which_region_update - 1)]) - ), - ) - ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - end - @assert !isnothing(ind) - future_verts = union(support(other_regions[ind])) - # return ortho_ceter as the vertex in current region that does not overlap with following one - overlapping_vertex = intersect(current_verts, future_verts) - nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex)) - return nonoverlapping_vertex -end - function region_update( projected_operator, state; @@ -55,14 +23,13 @@ function region_update( # ToDo: remove orthogonality center on vertex for generality # region carries same information - ortho_vertex = current_ortho(sweep_plan, which_region_update) if !isnothing(transform_operator) projected_operator = transform_operator( state, projected_operator; outputlevel, transform_operator_kwargs... ) end state, projected_operator, phi = extracter( - state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs + state, projected_operator, region; extracter_kwargs..., internal_kwargs ) # create references, in case solver does (out-of-place) modify PH or state state! = Ref(state) @@ -88,9 +55,7 @@ function region_update( # drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees... # so noiseterm is a solver #end - state, spec = inserter( - state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs - ) + state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs) all_kwargs = (; which_region_update, sweep_plan, diff --git a/src/solvers/extract/extract.jl b/src/solvers/extract/extract.jl index feb57c2f..431394ea 100644 --- a/src/solvers/extract/extract.jl +++ b/src/solvers/extract/extract.jl @@ -7,18 +7,19 @@ # insert_local_tensors takes that tensor and factorizes it back # apart and puts it back into the network. # -function default_extracter(state, projected_operator, region, ortho; internal_kwargs) - state = orthogonalize(state, ortho) +function default_extracter(state, projected_operator, region; internal_kwargs) if isa(region, AbstractEdge) - other_vertex = only(setdiff(support(region), [ortho])) - left_inds = uniqueinds(state[ortho], state[other_vertex]) + vsrc, vdst = src(region), dst(region) + state = orthogonalize(state, vsrc) + left_inds = uniqueinds(state[vsrc], state[vdst]) #ToDo: replace with call to factorize U, S, V = svd( - state[ortho], left_inds; lefttags=tags(state, region), righttags=tags(state, region) + state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region) ) - state[ortho] = U + state[vsrc] = U local_tensor = S * V else + state = orthogonalize(state, region) local_tensor = prod(state[v] for v in region) end projected_operator = position(projected_operator, state, region) diff --git a/src/solvers/insert/insert.jl b/src/solvers/insert/insert.jl index 11aed223..42b32023 100644 --- a/src/solvers/insert/insert.jl +++ b/src/solvers/insert/insert.jl @@ -6,8 +6,7 @@ function default_inserter( state::AbstractTTN, phi::ITensor, - region, - ortho_vert; + region; normalize=false, maxdim=nothing, mindim=nothing, @@ -16,16 +15,14 @@ function default_inserter( ) state = copy(state) spec = nothing - other_vertex = setdiff(support(region), [ortho_vert]) - if !isempty(other_vertex) - v = only(other_vertex) - e = edgetype(state)(ortho_vert, v) - indsTe = inds(state[ortho_vert]) + if length(region) == 2 + v = last(region) + e = edgetype(state)(first(region), last(region)) + indsTe = inds(state[first(region)]) L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff) - state[ortho_vert] = L - + state[first(region)] = L else - v = ortho_vert + v = only(region) end state[v] = phi state = set_ortho_region(state, [v]) @@ -44,8 +41,7 @@ function default_inserter( normalize=false, internal_kwargs, ) - v = only(setdiff(support(region), [ortho])) - state[v] *= phi - state = set_ortho_region(state, [v]) + state[dst(region)] *= phi + state = set_ortho_region(state, [dst(region)]) return state, nothing end diff --git a/src/treetensornetworks/abstracttreetensornetwork.jl b/src/treetensornetworks/abstracttreetensornetwork.jl index c8dccb1f..13d78197 100644 --- a/src/treetensornetworks/abstracttreetensornetwork.jl +++ b/src/treetensornetworks/abstracttreetensornetwork.jl @@ -1,6 +1,12 @@ using Graphs: has_vertex using NamedGraphs.GraphsExtensions: - GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices + GraphsExtensions, + edge_path, + leaf_vertices, + post_order_dfs_edges, + post_order_dfs_vertices, + a_star +using NamedGraphs: namedgraph_a_star using IsApprox: IsApprox, Approx using ITensors: @Algorithm_str, directsum, hasinds, permute, plev using ITensorMPS: linkind, loginner, lognorm, orthogonalize @@ -29,30 +35,16 @@ function set_ortho_region(tn::AbstractTTN, new_region) return error("Not implemented") end -# -# Orthogonalization -# - -function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...) - if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn)) - return tn - end - # TODO: Rewrite this in a more general way. - if isone(length(ortho_region(tn))) - edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center) - else - edge_list = post_order_dfs_edges(tn, ortho_center) - end - for e in edge_list - tn = orthogonalize(tn, e) +function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...) + paths = [ + namedgraph_a_star(underlying_graph(ttn), rp, r) for r in region for + rp in ortho_region(ttn) + ] + path = unique(reduce(vcat, paths)) + if !isempty(path) + ttn = typeof(ttn)(orthogonalize(ITensorNetwork(ttn), path; kwargs...)) end - return set_ortho_region(tn, typeof(ortho_region(tn))([ortho_center])) -end - -# For ambiguity error - -function ITensorMPS.orthogonalize(tn::AbstractTTN, edge::AbstractEdge; kwargs...) - return typeof(tn)(orthogonalize(ITensorNetwork(tn), edge; kwargs...)) + return set_ortho_region(ttn, region) end #