From 2cfba37caf45196201b85d5406e420e5cd5a4ad3 Mon Sep 17 00:00:00 2001 From: Robert Adam Date: Fri, 9 Feb 2024 16:15:53 +0100 Subject: [PATCH] Add executable to generate tensor network graphs --- CMakeLists.txt | 7 +- .../tensor_network_graphs.cpp | 64 +++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 examples/tensor_network_graphs/tensor_network_graphs.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 160d964b9..4acea9a18 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -543,8 +543,13 @@ if (BUILD_TESTING) target_link_libraries(${example${i}} SeQuant) endforeach () + set(example12 "tensor_network_graphs") + add_executable(${example12} EXCLUDE_FROM_ALL + examples/${example12}/${example12}.cpp) + target_link_libraries(${example12} SeQuant) + # add tests for running examples - set(lastexample 11) + set(lastexample 12) foreach (i RANGE ${lastexample}) if (TARGET ${example${i}}) add_dependencies(examples-sequant ${example${i}}) diff --git a/examples/tensor_network_graphs/tensor_network_graphs.cpp b/examples/tensor_network_graphs/tensor_network_graphs.cpp new file mode 100644 index 000000000..8da1d1ee3 --- /dev/null +++ b/examples/tensor_network_graphs/tensor_network_graphs.cpp @@ -0,0 +1,64 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + using namespace sequant; + +std::wstring from_utf8(std::string_view str) { + std::wstring_convert> converter; + return converter.from_bytes(std::string(str)); +} + +std::optional to_network(const ExprPtr &expr) { + if (expr.is()) { + return TensorNetwork({expr}); + } else if (expr.is()) { + for (const ExprPtr &factor : expr.as().factors()) { + if (!factor.is()) { + return {}; + } + } + + return TensorNetwork(expr.as().factors()); + } else { + return {}; + } +} + +int main(int argc, char **argv) { + + set_locale(); + mbpt::set_default_convention(); + + for (std::size_t i = 1; i < static_cast(argc); ++i) { + std::wstring current = from_utf8(argv[i]); + ExprPtr expr; + try { + expr = parse_expr(current); + } catch (const ParseError &e) { + std::wcout << "Failed to parse expression '" << current << "': " << e.what() << std::endl; + return 1; + } + assert(expr); + + std::optional network = to_network(expr); + if (!network.has_value()) { + std::wcout << "Failed to construct tensor network for input '" << current << "'" << std::endl; + return 2; + } + + TensorNetwork::Graph graph = network->create_graph(); + std::wcout << "Graph for '" << current << "'\n"; + graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels); + } +}