Skip to content

Commit

Permalink
Two site working
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jan 13, 2025
1 parent df7a4c5 commit e9e0336
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 129 deletions.
16 changes: 10 additions & 6 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,14 @@ function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
return ITensor[tensornetwork(bpc)[v] for v in verts]
end

function factor(bpc::AbstractBeliefPropagationCache, vertex::PartitionVertex)
return factors(bpc, vertices(bpc, vertex))
function factors(
bpc::AbstractBeliefPropagationCache, partition_verts::Vector{<:PartitionVertex}
)
return factors(bpc, vertices(bpc, partition_verts))
end

function factors(bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex)
return factors(bpc, [partition_vertex])
end

function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc); kwargs...)
Expand Down Expand Up @@ -190,7 +196,7 @@ function updated_message(
)
vertex = src(edge)
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factor(bpc, vertex)
state = factors(bpc, vertex)

return message_update_function(
ITensor[incoming_ms; state]; message_update_function_kwargs...
Expand All @@ -203,9 +209,7 @@ function update(
edge::PartitionEdge;
kwargs...,
)
new_m = updated_message(bpc, edge; kwargs...)
bpc = set_message(bpc, edge, new_m)
return bpc
return set_message(bpc, edge, updated_message(bpc, edge; kwargs...))
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function region_scalar(
contract_kwargs=(; sequence="automatic"),
)
incoming_mts = incoming_messages(bp_cache, [pv])
local_state = factor(bp_cache, pv)
local_state = factors(bp_cache, pv)
return contract(vcat(incoming_mts, local_state); contract_kwargs...)[]
end

Expand Down
Loading

0 comments on commit e9e0336

Please sign in to comment.