Skip to content

Commit

Permalink
tensor_network_graphs can use V1 and V2 TensorNetworks
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Nov 19, 2024
1 parent 1f80f51 commit 5e13a2c
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions examples/tensor_network_graphs/tensor_network_graphs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <SeQuant/core/runtime.hpp>
#include <SeQuant/core/tensor.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/tensor_network_v2.hpp>
#include <SeQuant/domain/mbpt/context.hpp>
#include <SeQuant/domain/mbpt/convention.hpp>

Expand Down Expand Up @@ -36,6 +37,22 @@ std::optional<TensorNetwork> to_network(const ExprPtr &expr) {
}
}

std::optional<TensorNetworkV2> to_network_v2(const ExprPtr &expr) {
if (expr.is<Tensor>()) {
return TensorNetworkV2({expr});
} else if (expr.is<Product>()) {
for (const ExprPtr &factor : expr.as<Product>().factors()) {
if (!factor.is<Tensor>()) {
return {};
}
}

return TensorNetworkV2(expr.as<Product>().factors());
} else {
return {};
}
}

void print_help() {
std::wcout << "Helper to generate dot (GraphViz) representations of tensor "
"network graphs.\n";
Expand All @@ -44,6 +61,7 @@ void print_help() {
<< " <exe> [options] <network 1> [<network 2> [... [<network N>] ] ]\n";
std::wcout << "Options:\n";
std::wcout << " --help Shows this help message\n";
std::wcout << " --v2 Use TensorNetworkV2\n";
std::wcout << " --no-named Treat all indices as unnamed (even if they are "
"external)\n";
}
Expand All @@ -55,6 +73,7 @@ int main(int argc, char **argv) {
BraKetSymmetry::conjugate, SPBasis::spinorbital));

bool use_named_indices = true;
bool use_tnv2 = false;
const TensorNetwork::named_indices_t empty_named_indices;

if (argc <= 1) {
Expand All @@ -70,6 +89,9 @@ int main(int argc, char **argv) {
} else if (current == L"--no-named") {
use_named_indices = false;
continue;
} else if (current == L"--v2") {
use_tnv2 = true;
continue;
}

ExprPtr expr;
Expand All @@ -82,16 +104,30 @@ int main(int argc, char **argv) {
}
assert(expr);

std::optional<TensorNetwork> network = to_network(expr);
if (!network.has_value()) {
std::wcout << "Failed to construct tensor network for input '" << current
<< "'" << std::endl;
return 2;
}
if (!use_tnv2) {
std::optional<TensorNetwork> network = to_network(expr);
if (!network.has_value()) {
std::wcout << "Failed to construct tensor network for input '"
<< current << "'" << std::endl;
return 2;
}

auto [graph, vlabels, vcolors, vtypes] = network->make_bliss_graph(
use_named_indices ? nullptr : &empty_named_indices);
std::wcout << "Graph for '" << current << "'\n";
graph->write_dot(std::wcout, vlabels);
} else {
std::optional<TensorNetworkV2> network = to_network_v2(expr);
if (!network.has_value()) {
std::wcout << "Failed to construct tensor network for input '"
<< current << "'" << std::endl;
return 2;
}

TensorNetwork::Graph graph = network->create_graph(
use_named_indices ? nullptr : &empty_named_indices);
std::wcout << "Graph for '" << current << "'\n";
graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels);
TensorNetworkV2::Graph graph = network->create_graph(
use_named_indices ? nullptr : &empty_named_indices);
std::wcout << "Graph for '" << current << "'\n";
graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels);
}
}
}

0 comments on commit 5e13a2c

Please sign in to comment.