diff --git a/SeQuant/core/tensor_network.hpp b/SeQuant/core/tensor_network.hpp index b73d10b17..70e84911f 100644 --- a/SeQuant/core/tensor_network.hpp +++ b/SeQuant/core/tensor_network.hpp @@ -13,6 +13,7 @@ #include #include #include +#include // forward declarations namespace bliss { @@ -145,20 +146,29 @@ class TensorNetwork { std::size_t vertex_to_tensor_idx(std::size_t vertex) const; }; - /// @throw std::logic_error if exprptr_range contains a non-tensor - /// @note uses RTTI - template - TensorNetwork(const ExprPtrRange &exprptr_range) { - for (const auto &ex : exprptr_range) { - ExprPtr clone = ex.clone(); - auto t = std::dynamic_pointer_cast(clone); - if (t) { - tensors_.emplace_back(std::move(t)); - } else { - throw std::logic_error( - "TensorNetwork::TensorNetwork: non-tensors in the given expression " - "range"); + TensorNetwork(const Expr &expr) { + if (expr.size() > 0) { + for (const ExprPtr &subexpr : expr) { + add_expr(*subexpr); } + } else { + add_expr(expr); + } + + init_edges(); + } + + TensorNetwork(const ExprPtr &expr) : TensorNetwork(*expr) {} + + template < + typename ExprPtrRange, + typename = std::enable_if_t && + !std::is_base_of_v>> + TensorNetwork(const ExprPtrRange &exprptr_range) { + static_assert( + std::is_base_of_v); + for (const ExprPtr ¤t : exprptr_range) { + add_expr(*current); } init_edges(); @@ -268,6 +278,18 @@ class TensorNetwork { ExprPtr do_individual_canonicalization( const TensorCanonicalizer &canonicalizer); + + void add_expr(const Expr &expr) { + ExprPtr clone = expr.clone(); + + auto tensor_ptr = std::dynamic_pointer_cast(clone); + if (!tensor_ptr) { + throw std::invalid_argument( + "TensorNetwork::TensorNetwork: tried to add non-tensor to network"); + } + + tensors_.push_back(std::move(tensor_ptr)); + } }; template