From 4efceb86dee7b15bcd37f2bb5b171180c305deb3 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 22 May 2024 14:22:34 -0400 Subject: [PATCH] Generalize siteinds constructor (#186) * Generalize siteinds constructor * Bump to v0.11.12 --- Project.toml | 2 +- src/sitetype.jl | 25 ++++++++++++++++--------- test/test_sitetype.jl | 12 +++++++++++- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index e20bc55b..a2cd740d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.11.11" +version = "0.11.12" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/sitetype.jl b/src/sitetype.jl index 407fa2e4..570f711c 100644 --- a/src/sitetype.jl +++ b/src/sitetype.jl @@ -11,18 +11,25 @@ function ITensors.siteind(d::Integer, v; addtags="", kwargs...) return ITensors.addtags(Index(d; tags="Site, $addtags", kwargs...), vertex_tag(v)) end -function ITensors.siteinds(sitetypes::AbstractDictionary, g::AbstractGraph; kwargs...) - is = IndsNetwork(g) - for v in vertices(g) - is[v] = [siteind(sitetypes[v], vertex_tag(v); kwargs...)] - end - return is +to_siteinds_callable(x) = Returns(x) +function to_siteinds_callable(x::AbstractDictionary) + return Base.Fix1(getindex, x) ∘ keytype(x) +end + +function ITensors.siteinds(x, g::AbstractGraph; kwargs...) + return siteinds(to_siteinds_callable(x), g; kwargs...) end -function ITensors.siteinds(sitetype, g::AbstractGraph; kwargs...) - return siteinds(Dictionary(vertices(g), fill(sitetype, nv(g))), g; kwargs...) +function to_siteind(x, vertex; kwargs...) + return [siteind(x, vertex_tag(vertex); kwargs...)] end +to_siteind(x::Index, vertex; kwargs...) = [x] + function ITensors.siteinds(f::Function, g::AbstractGraph; kwargs...) - return siteinds(Dictionary(vertices(g), map(v -> f(v), vertices(g))), g; kwargs...) + is = IndsNetwork(g) + for v in vertices(g) + is[v] = to_siteind(f(v), v; kwargs...) + end + return is end diff --git a/test/test_sitetype.jl b/test/test_sitetype.jl index 77075d8d..33283e91 100644 --- a/test/test_sitetype.jl +++ b/test/test_sitetype.jl @@ -3,7 +3,7 @@ using DataGraphs: vertex_data using Dictionaries: Dictionary using Graphs: nv, vertices using ITensorNetworks: IndsNetwork, siteinds -using ITensors: SiteType, hastags, space +using ITensors: Index, SiteType, hastags, space using ITensors.NDTensors: dim using NamedGraphs.NamedGraphGenerators: named_grid using Test: @test, @testset @@ -18,6 +18,16 @@ using Test: @test, @testset fdim(v::Tuple) = space(SiteType(ftype(v))) testtag = "TestTag" + d1 = map(v -> Index(2), vertices(g)) + d2 = map(v -> "S=1/2", vertices(g)) + for x in (v -> d1[v], d1, v -> d2[v], d2) + s = siteinds(x, g) + @test s[1, 1] isa Vector{<:Index} + @test s[1, 2] isa Vector{<:Index} + @test s[2, 1] isa Vector{<:Index} + @test s[2, 2] isa Vector{<:Index} + end + # uniform string sitetype s_us = siteinds(sitetypes[1], g; addtags=testtag) @test s_us isa IndsNetwork