Skip to content

Commit

Permalink
Make old graph repr use colors of new one
Browse files Browse the repository at this point in the history
  • Loading branch information
Krzmbrzl committed Jan 16, 2025
1 parent 0931c3d commit 04aeccc
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 282 deletions.
100 changes: 35 additions & 65 deletions SeQuant/core/tensor_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/tag.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/vertex_painter.hpp>
#include <SeQuant/core/wstring.hpp>

#include <algorithm>
Expand Down Expand Up @@ -425,6 +426,8 @@ TensorNetwork::make_bliss_graph(
const auto &named_indices =
named_indices_ptr == nullptr ? this->ext_indices() : *named_indices_ptr;

VertexPainter colorizer(named_indices);

// results
std::shared_ptr<bliss::Graph> graph;
std::vector<std::wstring> vertex_labels(
Expand All @@ -434,25 +437,6 @@ TensorNetwork::make_bliss_graph(
std::vector<VertexType> vertex_type(
edges_.size()); // the size will be updated

// N.B. Colors [0, 3 max rank + named_indices.size()) are reserved:
// 0 - the bra vertex (for particle 0, if bra is nonsymm, or for the entire
// bra, if (anti)symm) 1 - the bra vertex for particle 1, if bra is nonsymm
// ...
// max_rank - the ket vertex (for particle 0, if particle-asymmetric, or for
// the entire ket, if particle-symmetric) max_rank+1 - the ket vertex for
// particle 1, if particle-asymmetric
// ...
// 2 max_rank - the aux index
// ...
// 3 max_rank - first named index
// 3 max_rank + 1 - second named index
// ...
// N.B. For braket-symmetric tensors the ket vertices use the same indices as
// the bra vertices
auto nonreserved_color = [&named_indices](size_t color) -> bool {
return color >= 3 * max_rank + named_indices.size();
};

// compute # of vertices
size_t nv = 0;
size_t index_cnt = 0;
Expand All @@ -470,17 +454,8 @@ TensorNetwork::make_bliss_graph(
++nv; // each index is a vertex
vertex_labels.at(index_cnt) = idx.to_latex();
vertex_type.at(index_cnt) = VertexType::Index;
// assign color: named indices use reserved colors
const auto named_index_it = named_indices.find(idx);
if (named_index_it ==
named_indices.end()) { // anonymous index? use Index::color
const auto idx_color = idx.color();
assert(nonreserved_color(idx_color));
vertex_color.at(index_cnt) = idx_color;
} else {
const auto named_index_rank = named_index_it - named_indices.begin();
vertex_color.at(index_cnt) = 3 * max_rank + named_index_rank;
}
vertex_color.at(index_cnt) = colorizer(idx);

// each symmetric proto index bundle will have a vertex ...
// for now only store the unique protoindex bundles in
// symmetric_protoindex_bundles, then commit their data to
Expand All @@ -507,11 +482,10 @@ TensorNetwork::make_bliss_graph(
spbundle_label += L"}";
vertex_labels.push_back(spbundle_label);
vertex_type.push_back(VertexType::SPBundle);
const auto idx_proto_indices_color = Index::proto_indices_color(bundle);
assert(nonreserved_color(idx_proto_indices_color));
vertex_color.push_back(idx_proto_indices_color);
vertex_color.push_back(colorizer(bundle));
spbundle_cnt++;
});

// now account for vertex representation of tensors
size_t tensor_cnt = 0;
// this will map to tensor index to the first (core) vertex in its
Expand All @@ -524,10 +498,8 @@ TensorNetwork::make_bliss_graph(
const auto tlabel = label(*t);
vertex_labels.emplace_back(tlabel);
vertex_type.emplace_back(VertexType::TensorCore);
const auto t_color = hash::value(tlabel);
static_assert(sizeof(t_color) == sizeof(unsigned long int));
assert(nonreserved_color(t_color));
vertex_color.push_back(t_color);
vertex_color.push_back(colorizer(*t));

// symmetric/antisymmetric tensors are represented by 3 more vertices:
// - bra
// - ket
Expand All @@ -539,18 +511,24 @@ TensorNetwork::make_bliss_graph(
std::wstring(L"bra") + to_wstring(bra_rank(tref)) +
((symmetry(tref) == Symmetry::antisymm) ? L"a" : L"s"));
vertex_type.push_back(VertexType::TensorBra);
vertex_color.push_back(0);
vertex_color.push_back(colorizer(BraGroup{0}));
vertex_labels.push_back(
std::wstring(L"ket") + to_wstring(ket_rank(tref)) +
((symmetry(tref) == Symmetry::antisymm) ? L"a" : L"s"));
vertex_type.push_back(VertexType::TensorKet);
vertex_color.push_back(
braket_symmetry(tref) == BraKetSymmetry::symm ? 0 : max_rank);
if (braket_symmetry(tref) == BraKetSymmetry::symm) {
// Use BraGroup for kets as well as they are supposed to be
// indistinguishable
vertex_color.push_back(colorizer(BraGroup{0}));
} else {
vertex_color.push_back(colorizer(KetGroup{0}));
}
vertex_labels.push_back(
std::wstring(L"bk") +
((symmetry(tref) == Symmetry::antisymm) ? L"a" : L"s"));
vertex_type.push_back(VertexType::Particle);
vertex_color.push_back(t_color);
// Color bk node in same color as tensor core
vertex_color.push_back(colorizer(tref));
}
// nonsymmetric tensors are represented by 3*rank more vertices (with rank =
// max(bra_rank(),ket_rank())
Expand All @@ -562,18 +540,24 @@ TensorNetwork::make_bliss_graph(
auto pstr = to_wstring(p + 1);
vertex_labels.push_back(std::wstring(L"bra") + pstr);
vertex_type.push_back(VertexType::TensorBra);
const bool t_is_particle_symmetric =
const bool distinguishable_particles =
particle_symmetry(tref) == ParticleSymmetry::nonsymm;
const auto bra_color = t_is_particle_symmetric ? p : 0;
vertex_color.push_back(bra_color);
vertex_color.push_back(
colorizer(BraGroup{distinguishable_particles ? p : 0}));
vertex_labels.push_back(std::wstring(L"ket") + pstr);
vertex_type.push_back(VertexType::TensorKet);
vertex_color.push_back(braket_symmetry(tref) == BraKetSymmetry::symm
? bra_color
: bra_color + max_rank);
if (braket_symmetry(tref) == BraKetSymmetry::symm) {
// Use BraGroup for kets as well as they are supposed to be
// indistinguishable
vertex_color.push_back(
colorizer(BraGroup{distinguishable_particles ? p : 0}));
} else {
vertex_color.push_back(
colorizer(KetGroup{distinguishable_particles ? p : 0}));
}
vertex_labels.push_back(std::wstring(L"bk") + pstr);
vertex_type.push_back(VertexType::Particle);
vertex_color.push_back(t_color);
vertex_color.push_back(colorizer(tref));
}
}
// aux indices currently do not support any symmetry
Expand All @@ -583,8 +567,7 @@ TensorNetwork::make_bliss_graph(
auto pstr = to_wstring(p + 1);
vertex_labels.push_back(std::wstring(L"aux") + pstr);
vertex_type.push_back(VertexType::TensorAux);
const auto color = 2 * max_rank + p;
vertex_color.push_back(color);
vertex_color.push_back(colorizer(AuxGroup{p}));
}

++tensor_cnt;
Expand Down Expand Up @@ -659,21 +642,8 @@ TensorNetwork::make_bliss_graph(
++tensor_cnt;
});

// compress vertex colors to 32 bits, as required by Bliss, by hashing
size_t v_cnt = 0;
for (auto &&color : vertex_color) {
auto hash6432shift = [](size_t key) {
static_assert(sizeof(key) == 8);
key = (~key) + (key << 18); // key = (key << 18) - key - 1;
key = key ^ (key >> 31);
key = key * 21; // key = (key + (key << 2)) + (key << 4);
key = key ^ (key >> 11);
key = key + (key << 6);
key = key ^ (key >> 22);
return static_cast<int>(key);
};
graph->change_color(v_cnt, hash6432shift(color));
++v_cnt;
for (const auto [vertex, color] : ranges::views::enumerate(vertex_color)) {
graph->change_color(vertex, color);
}

return {graph, vertex_labels, vertex_color, vertex_type};
Expand Down
15 changes: 8 additions & 7 deletions tests/unit/test_canonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ TEST_CASE("Canonicalizer", "[algorithms]") {
canonicalize(input);
REQUIRE_THAT(
input,
SimplifiesTo("S{a1,a3;i1,i2} f{a2;i3} t{i3;a3} t{i1,i2;a1,a2}"));
SimplifiesTo("S{a1,a2;i1,i2} f{a3;i3} t{i3;a2} t{i1,i2;a1,a3}"));
}
{
auto input =
Expand Down Expand Up @@ -149,7 +149,7 @@ TEST_CASE("Canonicalizer", "[algorithms]") {
canonicalize(input1);
REQUIRE_THAT(
input1,
SimplifiesTo("S{a1,a3;i1,i2} f{a2;i3} f⁺{a1,a3;i1,i3} t{i2;a2}"));
SimplifiesTo("S{a1,a2;i1,i3} f{a3;i2} f⁺{a1,a2;i1,i2} t{i3;a3}"));
auto input2 =
ex<Tensor>(L"S", bra{L"a_1", L"a_2"}, ket{L"i_1", L"i_2"},
Symmetry::nonsymm) *
Expand All @@ -160,7 +160,7 @@ TEST_CASE("Canonicalizer", "[algorithms]") {
REQUIRE_THAT(
input2,
SimplifiesTo(
"1/2 w S{a1,a3;i1,i2} f{a2;i3} f⁺{a1,a3;i1,i3} t{i2;a2}"));
"1/2 w S{a1,a2;i1,i3} f{a3;i2} f⁺{a1,a2;i1,i2} t{i3;a3}"));
}
}
{
Expand Down Expand Up @@ -190,8 +190,9 @@ TEST_CASE("Canonicalizer", "[algorithms]") {
Symmetry::nonsymm) *
ex<Tensor>(L"t", bra{L"p_3"}, ket{L"p_1"}, Symmetry::nonsymm) *
ex<Tensor>(L"t", bra{L"p_4"}, ket{L"p_2"}, Symmetry::nonsymm);
simplify(input);
canonicalize(input);
REQUIRE_THAT(input, SimplifiesTo("g{p2,p3;p1,p4} t{p1;p2} t{p4;p3}"));
REQUIRE_THAT(input, SimplifiesTo("g{p3,p4;p1,p2} t{p1;p3} t{p2;p4}"));
}

// CASE 2: Symmetric tensors
Expand Down Expand Up @@ -268,7 +269,7 @@ TEST_CASE("Canonicalizer", "[algorithms]") {
canonicalize(input);
REQUIRE(input->size() == 1);
REQUIRE_THAT(input,
SimplifiesTo("g{i3,i4;a3,i1} t{a2;i3} t{a1,a3;i4,i2}"));
SimplifiesTo("g{i3,i4;i1,a3} t{a2;i4} t{a1,a3;i3,i2}"));
}

{ // Case 5: CCSDT R3: S3 * F * T3
Expand Down Expand Up @@ -369,10 +370,10 @@ TEST_CASE("Canonicalizer", "[algorithms]") {
Symmetry::nonsymm);

canonicalize(input);
REQUIRE(input->size() == 1);
simplify(input);
REQUIRE_THAT(
input,
SimplifiesTo("t{a2;i3} t{a1,a3;i4,i2} B{i3;a3;p5} B{i4;i1;p5}"));
SimplifiesTo("t{a2;i4} t{a1,a3;i3,i2} B{i3;i1;p5} B{i4;a3;p5}"));
}
}
}
Loading

0 comments on commit 04aeccc

Please sign in to comment.