Skip to content

Commit

Permalink
Add executable to generate tensor network graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
Krzmbrzl committed Mar 26, 2024
1 parent a316a41 commit 2cfba37
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,13 @@ if (BUILD_TESTING)
target_link_libraries(${example${i}} SeQuant)
endforeach ()

set(example12 "tensor_network_graphs")
add_executable(${example12} EXCLUDE_FROM_ALL
examples/${example12}/${example12}.cpp)
target_link_libraries(${example12} SeQuant)

# add tests for running examples
set(lastexample 11)
set(lastexample 12)
foreach (i RANGE ${lastexample})
if (TARGET ${example${i}})
add_dependencies(examples-sequant ${example${i}})
Expand Down
64 changes: 64 additions & 0 deletions examples/tensor_network_graphs/tensor_network_graphs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include <SeQuant/core/runtime.hpp>
#include <SeQuant/core/bliss.hpp>
#include <SeQuant/core/parse_expr.hpp>
#include <SeQuant/core/tensor.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/domain/mbpt/convention.hpp>

#include <string>
#include <cassert>
#include <codecvt>
#include <locale>
#include <iostream>
#include <optional>

using namespace sequant;

std::wstring from_utf8(std::string_view str) {
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(std::string(str));
}

std::optional<TensorNetwork> to_network(const ExprPtr &expr) {
if (expr.is<Tensor>()) {
return TensorNetwork({expr});
} else if (expr.is<Product>()) {
for (const ExprPtr &factor : expr.as<Product>().factors()) {
if (!factor.is<Tensor>()) {
return {};
}
}

return TensorNetwork(expr.as<Product>().factors());
} else {
return {};
}
}

int main(int argc, char **argv) {

set_locale();
mbpt::set_default_convention();

for (std::size_t i = 1; i < static_cast<std::size_t>(argc); ++i) {
std::wstring current = from_utf8(argv[i]);
ExprPtr expr;
try {
expr = parse_expr(current);
} catch (const ParseError &e) {
std::wcout << "Failed to parse expression '" << current << "': " << e.what() << std::endl;
return 1;
}
assert(expr);

std::optional<TensorNetwork> network = to_network(expr);
if (!network.has_value()) {
std::wcout << "Failed to construct tensor network for input '" << current << "'" << std::endl;
return 2;
}

TensorNetwork::Graph graph = network->create_graph();
std::wcout << "Graph for '" << current << "'\n";
graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels);
}
}

0 comments on commit 2cfba37

Please sign in to comment.