diff --git a/examples/tensor_network_graphs/tensor_network_graphs.cpp b/examples/tensor_network_graphs/tensor_network_graphs.cpp index 62e335d96..3a6fd7062 100644 --- a/examples/tensor_network_graphs/tensor_network_graphs.cpp +++ b/examples/tensor_network_graphs/tensor_network_graphs.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -36,6 +37,22 @@ std::optional to_network(const ExprPtr &expr) { } } +std::optional to_network_v2(const ExprPtr &expr) { + if (expr.is()) { + return TensorNetworkV2({expr}); + } else if (expr.is()) { + for (const ExprPtr &factor : expr.as().factors()) { + if (!factor.is()) { + return {}; + } + } + + return TensorNetworkV2(expr.as().factors()); + } else { + return {}; + } +} + void print_help() { std::wcout << "Helper to generate dot (GraphViz) representations of tensor " "network graphs.\n"; @@ -44,6 +61,7 @@ void print_help() { << " [options] [ [... [] ] ]\n"; std::wcout << "Options:\n"; std::wcout << " --help Shows this help message\n"; + std::wcout << " --v2 Use TensorNetworkV2\n"; std::wcout << " --no-named Treat all indices as unnamed (even if they are " "external)\n"; } @@ -55,6 +73,7 @@ int main(int argc, char **argv) { BraKetSymmetry::conjugate, SPBasis::spinorbital)); bool use_named_indices = true; + bool use_tnv2 = false; const TensorNetwork::named_indices_t empty_named_indices; if (argc <= 1) { @@ -70,6 +89,9 @@ int main(int argc, char **argv) { } else if (current == L"--no-named") { use_named_indices = false; continue; + } else if (current == L"--v2") { + use_tnv2 = true; + continue; } ExprPtr expr; @@ -82,16 +104,30 @@ int main(int argc, char **argv) { } assert(expr); - std::optional network = to_network(expr); - if (!network.has_value()) { - std::wcout << "Failed to construct tensor network for input '" << current - << "'" << std::endl; - return 2; - } + if (!use_tnv2) { + std::optional network = to_network(expr); + if (!network.has_value()) { + std::wcout << "Failed to construct tensor network for input '" + << current << "'" << std::endl; + return 2; + } + + auto [graph, vlabels, vcolors, vtypes] = network->make_bliss_graph( + use_named_indices ? nullptr : &empty_named_indices); + std::wcout << "Graph for '" << current << "'\n"; + graph->write_dot(std::wcout, vlabels); + } else { + std::optional network = to_network_v2(expr); + if (!network.has_value()) { + std::wcout << "Failed to construct tensor network for input '" + << current << "'" << std::endl; + return 2; + } - TensorNetwork::Graph graph = network->create_graph( - use_named_indices ? nullptr : &empty_named_indices); - std::wcout << "Graph for '" << current << "'\n"; - graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels); + TensorNetworkV2::Graph graph = network->create_graph( + use_named_indices ? nullptr : &empty_named_indices); + std::wcout << "Graph for '" << current << "'\n"; + graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels); + } } }