Skip to content

Commit

Permalink
Make TensorNetworks constructible from single tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Krzmbrzl committed Feb 15, 2024
1 parent dfe887c commit 307b716
Showing 1 changed file with 35 additions and 13 deletions.
48 changes: 35 additions & 13 deletions SeQuant/core/tensor_network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <cassert>
#include <iosfwd>
#include <memory>
#include <type_traits>

// forward declarations
namespace bliss {
Expand Down Expand Up @@ -145,20 +146,29 @@ class TensorNetwork {
std::size_t vertex_to_tensor_idx(std::size_t vertex) const;
};

/// @throw std::logic_error if exprptr_range contains a non-tensor
/// @note uses RTTI
template <typename ExprPtrRange>
TensorNetwork(const ExprPtrRange &exprptr_range) {
for (const auto &ex : exprptr_range) {
ExprPtr clone = ex.clone();
auto t = std::dynamic_pointer_cast<AbstractTensor>(clone);
if (t) {
tensors_.emplace_back(std::move(t));
} else {
throw std::logic_error(
"TensorNetwork::TensorNetwork: non-tensors in the given expression "
"range");
TensorNetwork(const Expr &expr) {
if (expr.size() > 0) {
for (const ExprPtr &subexpr : expr) {
add_expr(*subexpr);
}
} else {
add_expr(expr);
}

init_edges();
}

TensorNetwork(const ExprPtr &expr) : TensorNetwork(*expr) {}

template <
typename ExprPtrRange,
typename = std::enable_if_t<!std::is_base_of_v<ExprPtr, ExprPtrRange> &&
!std::is_base_of_v<Expr, ExprPtrRange>>>
TensorNetwork(const ExprPtrRange &exprptr_range) {
static_assert(
std::is_base_of_v<ExprPtr, typename ExprPtrRange::value_type>);
for (const ExprPtr &current : exprptr_range) {
add_expr(*current);
}

init_edges();
Expand Down Expand Up @@ -268,6 +278,18 @@ class TensorNetwork {

ExprPtr do_individual_canonicalization(
const TensorCanonicalizer &canonicalizer);

void add_expr(const Expr &expr) {
ExprPtr clone = expr.clone();

auto tensor_ptr = std::dynamic_pointer_cast<AbstractTensor>(clone);
if (!tensor_ptr) {
throw std::invalid_argument(
"TensorNetwork::TensorNetwork: tried to add non-tensor to network");
}

tensors_.push_back(std::move(tensor_ptr));
}
};

template <typename CharT, typename Traits>
Expand Down

0 comments on commit 307b716

Please sign in to comment.